{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.layers"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Layers"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Helper functions used to build PyTorch time series models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from torch.jit import TracerWarning\n",
    "import warnings\n",
    "from torch.nn.utils import weight_norm, spectral_norm\n",
    "from torch.nn.init import normal_\n",
    "from fastcore.basics import snake2camel\n",
    "from fastcore.test import test_eq\n",
    "from fastai.layers import *\n",
    "from fastai.losses import *\n",
    "from tsai.imports import *\n",
    "from tsai.utils import *\n",
    "\n",
    "warnings.filterwarnings(\"ignore\", category=TracerWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def test_module_to_torchscript(\n",
    "    m:torch.nn.Module, # The PyTorch module to be tested.\n",
    "    inputs:Tensor, # A tensor or tuple of tensors representing the inputs to the model.\n",
    "    trace:bool=True, # If `True`, attempts to trace the model. Defaults to `True`.\n",
    "    script:bool=True, # If `True`, attempts to script the model. Defaults to `True`.\n",
    "    serialize:bool=True, # If `True`, saves and loads the traced/scripted module to ensure it can be serialized. Defaults to `True`.\n",
    "    verbose:bool=True, # If `True`, prints detailed information about the tracing and scripting process. Defaults to `True`.\n",
    "):\n",
    "    \"Tests if a PyTorch module can be correctly traced or scripted and serialized. Returns `True` as soon as one attempt succeeds, `False` otherwise\"\n",
    "\n",
    "    m = m.eval()\n",
    "    m_name = m.__class__.__name__\n",
    "\n",
    "    # Tuples/lists are unpacked into positional arguments; anything else is passed as a single input\n",
    "    inp_is_tuple = isinstance(inputs, (tuple, list))\n",
    "\n",
    "    def _forward(mod):\n",
    "        \"Run `mod` on `inputs`\"\n",
    "        return mod(*inputs) if inp_is_tuple else mod(inputs)\n",
    "\n",
    "    def _check(jit_m, file_path):\n",
    "        \"Compare `jit_m` (and, if `serialize`, its saved and reloaded copy) against the eager output\"\n",
    "        torch.testing.assert_close(_forward(jit_m), output)\n",
    "        if not serialize: return\n",
    "        try:\n",
    "            torch.jit.save(jit_m, file_path)\n",
    "            # The reloaded module is evaluated too, so the save/load round trip itself is verified\n",
    "            torch.testing.assert_close(_forward(torch.jit.load(file_path)), output)\n",
    "        finally:\n",
    "            # Always remove the temporary file, even when saving or loading failed\n",
    "            if file_path.exists(): file_path.unlink()\n",
    "\n",
    "    # Get the model's output\n",
    "    output = _forward(m)\n",
    "    output_shapes = output.shape if not isinstance(output, (tuple, list)) else [o.shape for o in output]\n",
    "    if verbose:\n",
    "        print(f\"output.shape: {output_shapes}\")\n",
    "\n",
    "    # Try tracing the model\n",
    "    if trace:\n",
    "        if verbose:\n",
    "            print(\"Tracing...\")\n",
    "        try:\n",
    "            _check(torch.jit.trace(m, inputs), Path(f\"test_traced_{m_name}.pt\"))\n",
    "            if verbose:\n",
    "                print(f\"...{m_name} has been successfully traced 😃\\n\")\n",
    "            return True\n",
    "        except Exception as e:\n",
    "            # A failed trace is reported (not raised) so that scripting can still be attempted\n",
    "            if verbose:\n",
    "                print(f\"{m_name} cannot be traced 😔\")\n",
    "                print(e)\n",
    "                print(\"\\n\")\n",
    "\n",
    "    # Try scripting the model\n",
    "    if script:\n",
    "        if verbose:\n",
    "            print(\"Scripting...\")\n",
    "        try:\n",
    "            _check(torch.jit.script(m), Path(f\"test_scripted_{m_name}.pt\"))\n",
    "            if verbose:\n",
    "                print(f\"...{m_name} has been successfully scripted 😃\\n\")\n",
    "            return True\n",
    "        except Exception as e:\n",
    "            if verbose:\n",
    "                print(f\"{m_name} cannot be scripted 😔\")\n",
    "                print(e)\n",
    "\n",
    "    return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output.shape: torch.Size([3, 2])\n",
      "Tracing...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "...Linear has been successfully traced 😃\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A plain linear layer should trace cleanly (every option is left at its `True` default)\n",
    "m = nn.Linear(in_features=10, out_features=2)\n",
    "inp = torch.randn(size=(3, 10))\n",
    "test_module_to_torchscript(m, inputs=inp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def init_lin_zero(m):\n",
    "    \"Set the weight and bias (if any) of every `nn.Linear` in `m` (including `m` itself) to zero\"\n",
    "    # `modules()` yields `m` and all of its descendants\n",
    "    for layer in m.modules():\n",
    "        if not isinstance(layer, nn.Linear): continue\n",
    "        bias = getattr(layer, 'bias', None)\n",
    "        if bias is not None: nn.init.zeros_(bias)\n",
    "        nn.init.zeros_(layer.weight)\n",
    "\n",
    "lin_zero_init = init_lin_zero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export    \n",
    "class SwishBeta(Module):\n",
    "    \"Swish activation `x * sigmoid(beta * x)` with a learnable `beta`\"\n",
    "    def __init__(self, beta=1.): # Initial value of the learnable `beta`\n",
    "        # This constructor used to be spelled `__multiinit__`, which is never invoked: `beta` and\n",
    "        # `sigmoid` were therefore never created and `forward` could not run.\n",
    "        self.sigmoid = torch.sigmoid\n",
    "        self.beta = nn.Parameter(torch.Tensor(1).fill_(beta))\n",
    "    def forward(self, x): return x.mul(self.sigmoid(x*self.beta))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class SmeLU(nn.Module):\n",
    "    \"Smooth ReLU activation function based on https://arxiv.org/pdf/2202.06499.pdf\"\n",
    "\n",
    "    def __init__(self, \n",
    "        beta: float = 2. # Half-width of the quadratic transition region (its absolute value is used)\n",
    "        ) -> None:\n",
    "        super().__init__()\n",
    "        self.beta = abs(beta)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        # With beta == 0 the transition region vanishes and the function is plain ReLU. Returning it\n",
    "        # directly avoids the 0/0 = NaN the quadratic branch would produce at x == 0.\n",
    "        if self.beta == 0: return F.relu(x)\n",
    "        # (x + beta)^2 / (4 * beta) where |x| <= beta, ReLU elsewhere (0 below -beta, x above beta)\n",
    "        return torch.where(torch.abs(x) <= self.beta, ((x + self.beta) ** 2) / (4. * self.beta), F.relu(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class Chomp1d(nn.Module):\n",
    "    \"Remove the last `chomp_size` steps along the last dimension of a 3d tensor\"\n",
    "    def __init__(self, chomp_size):\n",
    "        super(Chomp1d, self).__init__()\n",
    "        self.chomp_size = chomp_size\n",
    "\n",
    "    def forward(self, x):\n",
    "        # `x[:, :, :-0]` would be an empty tensor, so nothing is sliced off when `chomp_size` is 0\n",
    "        if self.chomp_size == 0: return x.contiguous()\n",
    "        return x[:, :, :-self.chomp_size].contiguous()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def same_padding1d(seq_len, ks, stride=1, dilation=1):\n",
    "    \"Left and right padding given by Tensorflow's 'same' formula (any odd unit goes to the right)\"\n",
    "    dilated_ks = (ks - 1) * dilation + 1\n",
    "    total = (seq_len - 1) * stride + dilated_ks - seq_len\n",
    "    left = total // 2\n",
    "    return left, total - left\n",
    "\n",
    "\n",
    "class Pad1d(nn.ConstantPad1d):\n",
    "    \"`nn.ConstantPad1d` whose fill `value` defaults to 0\"\n",
    "    def __init__(self, padding, value=0.):\n",
    "        nn.ConstantPad1d.__init__(self, padding=padding, value=value)\n",
    "\n",
    "        \n",
    "# @delegates(nn.Conv1d.__init__)\n",
    "class SameConv1d(Module):\n",
    "    \"Conv1d with padding='same'\"\n",
    "    def __init__(self, ni, nf, ks=3, stride=1, dilation=1, **kwargs):\n",
    "        self.ks = ks\n",
    "        self.stride = stride\n",
    "        self.dilation = dilation\n",
    "        conv = nn.Conv1d(ni, nf, ks, stride=stride, dilation=dilation, **kwargs)\n",
    "        self.conv1d_same = conv\n",
    "        # The wrapped conv's parameters are also exposed directly on the wrapper\n",
    "        self.weight, self.bias = conv.weight, conv.bias\n",
    "        self.pad = Pad1d\n",
    "\n",
    "    def forward(self, x):\n",
    "        # The padding depends on the input length, so it is recomputed on every call (`stride` is not used here)\n",
    "        self.padding = same_padding1d(x.shape[-1], self.ks, dilation=self.dilation)\n",
    "        padded = self.pad(self.padding)(x)\n",
    "        return self.conv1d_same(padded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def same_padding2d(H, W, ks, stride=(1, 1), dilation=(1, 1)):\n",
    "    \"Same padding formula as used in Tensorflow. Returns the padding for the width (2 values) followed by the height (2 values)\"\n",
    "    # Scalars are broadcast to both dimensions (previously an int `stride` or `dilation` raised a TypeError)\n",
    "    if isinstance(ks, Integral): ks = (ks, ks)\n",
    "    if isinstance(stride, Integral): stride = (stride, stride)\n",
    "    if isinstance(dilation, Integral): dilation = (dilation, dilation)\n",
    "\n",
    "    def _total(size, k, s, d):\n",
    "        \"Total padding along one dimension; a kernel of size 1 is never padded\"\n",
    "        if k == 1: return 0\n",
    "        return (size - 1) * s + (k - 1) * d + 1 - size\n",
    "\n",
    "    p_h = _total(H, ks[0], stride[0], dilation[0])\n",
    "    p_w = _total(W, ks[1], stride[1], dilation[1])\n",
    "    # Any odd unit of padding goes to the second value of each pair\n",
    "    return (p_w // 2, p_w - p_w // 2, p_h // 2, p_h - p_h // 2)\n",
    "\n",
    "\n",
    "class Pad2d(nn.ConstantPad2d):\n",
    "    \"`nn.ConstantPad2d` whose fill `value` defaults to 0\"\n",
    "    def __init__(self, padding, value=0.):\n",
    "        nn.ConstantPad2d.__init__(self, padding=padding, value=value)\n",
    "\n",
    "\n",
    "# @delegates(nn.Conv2d.__init__)\n",
    "class Conv2dSame(Module):\n",
    "    \"Conv2d with padding='same'\"\n",
    "    def __init__(self, ni, nf, ks=(3, 3), stride=(1, 1), dilation=(1, 1), **kwargs):\n",
    "        # Scalar settings are broadcast to (height, width) pairs\n",
    "        ks, stride, dilation = [(v, v) if isinstance(v, Integral) else v for v in (ks, stride, dilation)]\n",
    "        self.ks = ks\n",
    "        self.stride = stride\n",
    "        self.dilation = dilation\n",
    "        conv = nn.Conv2d(ni, nf, ks, stride=stride, dilation=dilation, **kwargs)\n",
    "        self.conv2d_same = conv\n",
    "        # The wrapped conv's parameters are also exposed directly on the wrapper\n",
    "        self.weight, self.bias = conv.weight, conv.bias\n",
    "        self.pad = Pad2d\n",
    "\n",
    "    def forward(self, x):\n",
    "        # The padding depends on the input's height and width, so it is recomputed on every call (`stride` is not used here)\n",
    "        height, width = x.shape[-2], x.shape[-1]\n",
    "        self.padding = same_padding2d(height, width, self.ks, dilation=self.dilation)\n",
    "        padded = self.pad(self.padding)(x)\n",
    "        return self.conv2d_same(padded)\n",
    "    \n",
    "    \n",
    "# @delegates(nn.Conv2d.__init__)\n",
    "def Conv2d(ni, nf, kernel_size=None, ks=None, stride=1, padding='same', dilation=1, init='auto', bias_std=0.01, **kwargs):\n",
    "    \"conv2d layer with padding='same', 'valid', or any integer (defaults to 'same')\"\n",
    "    # The kernel size can be given under either name, but not both\n",
    "    assert not (kernel_size and ks), 'use kernel_size or ks but not both simultaneously'\n",
    "    assert kernel_size is not None or ks is not None, 'you need to pass a ks'\n",
    "    kernel_size = kernel_size or ks\n",
    "    if padding == 'same':\n",
    "        # `Conv2dSame` computes its padding at call time from the input's height and width\n",
    "        conv = Conv2dSame(ni, nf, kernel_size, stride=stride, dilation=dilation, **kwargs)\n",
    "    else:\n",
    "        # 'valid' means no padding; any other value is handed to `nn.Conv2d` unchanged\n",
    "        conv = nn.Conv2d(ni, nf, kernel_size, stride=stride, padding=0 if padding == 'valid' else padding, dilation=dilation, **kwargs)\n",
    "    init_linear(conv, None, init=init, bias_std=bias_std)\n",
    "    return conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs, c_in, c_out = 2, 3, 5\n",
    "h, w = 16, 20\n",
    "t = torch.rand(bs, c_in, h, w)\n",
    "\n",
    "# With stride 1 the spatial size is preserved, whatever the kernel shape or dilation\n",
    "for conv_kwargs in (dict(ks=3, stride=1, dilation=1), dict(ks=(3, 1), stride=1, dilation=1), dict(ks=3, stride=(1, 1), dilation=(2, 2))):\n",
    "    test_eq(Conv2dSame(c_in, c_out, bias=False, **conv_kwargs)(t).shape, (bs, c_out, h, w))\n",
    "\n",
    "# With stride 2 both spatial dimensions are halved\n",
    "for dil in ((1, 1), (2, 2)):\n",
    "    test_eq(Conv2dSame(c_in, c_out, ks=3, stride=(2, 2), dilation=dil, bias=False)(t).shape, (bs, c_out, h//2, w//2))\n",
    "\n",
    "# The factory function defaults to the same behaviour\n",
    "test_eq(Conv2d(c_in, c_out, ks=3, padding='same', stride=1, dilation=1, bias=False)(t).shape, (bs, c_out, h, w))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class CausalConv1d(torch.nn.Conv1d):\n",
    "    \"Conv1d whose input is padded on the left only, by `(ks - 1) * dilation` steps\"\n",
    "    def __init__(self, ni, nf, ks, stride=1, dilation=1, groups=1, bias=True):\n",
    "        super().__init__(in_channels=ni, out_channels=nf, kernel_size=ks, stride=stride, padding=0, dilation=dilation, groups=groups, bias=bias)\n",
    "        self.__padding = (ks - 1) * dilation\n",
    "    def forward(self, input):\n",
    "        left_padded = F.pad(input, (self.__padding, 0))\n",
    "        return super().forward(left_padded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# @delegates(nn.Conv1d.__init__)\n",
    "def Conv1d(ni, nf, kernel_size=None, ks=None, stride=1, padding='same', dilation=1, init='auto', bias_std=0.01, **kwargs):\n",
    "    \"conv1d layer with padding='same', 'causal', 'valid', or any integer (defaults to 'same')\"\n",
    "    # The kernel size can be given under either name, but not both\n",
    "    assert not (kernel_size and ks), 'use kernel_size or ks but not both simultaneously'\n",
    "    assert kernel_size is not None or ks is not None, 'you need to pass a ks'\n",
    "    kernel_size = kernel_size or ks\n",
    "    shared = dict(stride=stride, dilation=dilation, **kwargs)\n",
    "    if padding == 'causal':\n",
    "        conv = CausalConv1d(ni, nf, kernel_size, **shared)\n",
    "    elif padding == 'same' and kernel_size % 2 == 0:\n",
    "        # An even kernel needs asymmetric padding, which `SameConv1d` computes at call time\n",
    "        conv = SameConv1d(ni, nf, kernel_size, **shared)\n",
    "    else:\n",
    "        # Odd kernel with 'same', 'valid' or an explicit amount: a regular `nn.Conv1d` is enough\n",
    "        if padding == 'same': padding = kernel_size//2 * dilation\n",
    "        elif padding == 'valid': padding = 0\n",
    "        conv = nn.Conv1d(ni, nf, kernel_size, padding=padding, **shared)\n",
    "    init_linear(conv, None, init=init, bias_std=bias_std)\n",
    "    return conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs, c_in, c_out, seq_len = 2, 3, 5, 512\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "\n",
    "# For an odd kernel, causal and 'same' padding must give the same output shape\n",
    "for dilation in (1, 2):\n",
    "    causal_shape = CausalConv1d(c_in, c_out, ks=3, dilation=dilation)(t).shape\n",
    "    same_shape = Conv1d(c_in, c_out, ks=3, padding=\"same\", dilation=dilation)(t).shape\n",
    "    test_eq(causal_shape, same_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "ni = 3\n",
    "nf = 5\n",
    "seq_len = 6\n",
    "ks = 3\n",
    "# Only names defined in this cell are used (it previously relied on `c_in`/`c_out` left over from an earlier cell)\n",
    "t = torch.rand(bs, ni, seq_len)\n",
    "# No padding: the output loses `ks//2` steps on each side\n",
    "test_eq(Conv1d(ni, nf, ks, padding=0)(t).shape, (bs, nf, seq_len - (2 * (ks//2))))\n",
    "test_eq(Conv1d(ni, nf, ks, padding='valid')(t).shape, (bs, nf, seq_len - (2 * (ks//2))))\n",
    "# 'same' and 'causal' padding keep the sequence length\n",
    "test_eq(Conv1d(ni, nf, ks, padding='same')(t).shape, (bs, nf, seq_len))\n",
    "test_eq(Conv1d(ni, nf, ks, padding='causal')(t).shape, (bs, nf, seq_len))\n",
    "# The kernel size must be given exactly once\n",
    "test_error('use kernel_size or ks but not both simultaneously', Conv1d, ni, nf, kernel_size=3, ks=3)\n",
    "test_error('you need to pass a ks', Conv1d, ni, nf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv1d(3, 5, kernel_size=(3,), stride=(1,), padding=(1,))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 'same' padding with the odd kernel size used here gives a regular conv layer with symmetric padding\n",
    "conv = Conv1d(ni, nf, kernel_size=ks, padding='same')\n",
    "init_linear(conv, None, init='auto', bias_std=0.01)\n",
    "conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CausalConv1d(3, 5, kernel_size=(3,), stride=(1,))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 'causal' padding gives a `CausalConv1d`\n",
    "conv = Conv1d(ni, nf, kernel_size=ks, padding='causal')\n",
    "init_linear(conv, None, init='auto', bias_std=0.01)\n",
    "conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv1d(3, 5, kernel_size=(3,), stride=(1,))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 'valid' means no padding; weight normalization is then applied to the layer\n",
    "conv = Conv1d(ni, nf, kernel_size=ks, padding='valid')\n",
    "init_linear(conv, None, init='auto', bias_std=0.01)\n",
    "weight_norm(conv)\n",
    "conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv1d(3, 5, kernel_size=(3,), stride=(1,))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# An explicit integer padding is passed through unchanged; weight normalization is then applied to the layer\n",
    "conv = Conv1d(ni, nf, kernel_size=ks, padding=0)\n",
    "init_linear(conv, None, init='auto', bias_std=0.01)\n",
    "weight_norm(conv)\n",
    "conv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class SeparableConv1d(Module):\n",
    "    \"Depthwise conv (one group per input channel) followed by a pointwise (kernel size 1) conv\"\n",
    "    def __init__(self, ni, nf, ks, stride=1, padding='same', dilation=1, bias=True, bias_std=0.01):\n",
    "        self.depthwise_conv = Conv1d(ni, ni, ks, stride=stride, padding=padding, dilation=dilation, groups=ni, bias=bias)\n",
    "        self.pointwise_conv = nn.Conv1d(ni, nf, 1, stride=1, padding=0, dilation=1, groups=1, bias=bias)\n",
    "        if not bias: return\n",
    "        # Biases are drawn from N(0, bias_std), or zeroed when `bias_std` is 0\n",
    "        for conv in (self.depthwise_conv, self.pointwise_conv):\n",
    "            if bias_std != 0: normal_(conv.bias, 0, bias_std)\n",
    "            else: conv.bias.data.zero_()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.pointwise_conv(self.depthwise_conv(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs, c_in, c_out, seq_len = 64, 6, 5, 512\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "# With the default 'same' padding only the number of channels changes\n",
    "sep_out = SeparableConv1d(c_in, c_out, 3)(t)\n",
    "test_eq(sep_out.shape, (bs, c_out, seq_len))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class AddCoords1d(Module):\n",
    "    \"\"\"Add coordinates to ease position identification without modifying mean and std\"\"\"\n",
    "    def forward(self, x):\n",
    "        bs, _, seq_len = x.shape\n",
    "        # One extra channel with positions evenly spaced in [-1, 1], repeated for every sample\n",
    "        cc = torch.linspace(-1, 1, seq_len, device=x.device).repeat(bs, 1, 1)\n",
    "        # The channel is standardized. With a single step its std is 0 (or NaN), which would turn the\n",
    "        # whole channel into NaN, so a constant zero coordinate is used in that case.\n",
    "        cc = (cc - cc.mean()) / cc.std() if seq_len > 1 else torch.zeros_like(cc)\n",
    "        return torch.cat([x, cc], dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs, c_in, c_out, seq_len = 2, 3, 5, 50\n",
    "\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "t = (t - t.mean()) / t.std()\n",
    "new_t = AddCoords1d()(t)\n",
    "# One coordinate channel is appended ...\n",
    "test_eq(new_t.shape, (bs, c_in + 1, seq_len))\n",
    "# ... and the overall mean and std of a standardized input stay (almost) unchanged\n",
    "test_close(new_t.mean(), 0, 1e-2)\n",
    "test_close(new_t.std(), 1, 1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class ConvBlock(nn.Sequential):\n",
    "    \"Create a sequence of conv1d (`ni` to `nf`), activation (if `act`) and `norm` layers.\"\n",
    "    def __init__(self, ni, nf, kernel_size=None, ks=3, stride=1, padding='same', bias=None, bias_std=0.01, norm='Batch', zero_norm=False, bn_1st=True,\n",
    "                 act=nn.ReLU, act_kwargs={}, init='auto', dropout=0., xtra=None, coord=False, separable=False,  **kwargs):\n",
    "        kernel_size = kernel_size or ks\n",
    "        ndim = 1\n",
    "        # Optional coordinate channel: it adds one input channel to the conv (hence `ni + coord` below)\n",
    "        layers = [AddCoords1d()] if coord else []\n",
    "        norm_type = getattr(NormType,f\"{snake2camel(norm)}{'Zero' if zero_norm else ''}\") if norm is not None else None\n",
    "        bn = norm_type in (NormType.Batch, NormType.BatchZero)\n",
    "        inn = norm_type in (NormType.Instance, NormType.InstanceZero)\n",
    "        # Unless set explicitly, the conv only gets a bias when no batch/instance norm is used\n",
    "        if bias is None: bias = not (bn or inn)\n",
    "        # `bias_std` is forwarded so the separable path honours it too (it used to be silently ignored there)\n",
    "        if separable: conv = SeparableConv1d(ni + coord, nf, ks=kernel_size, bias=bias, bias_std=bias_std, stride=stride, padding=padding, **kwargs)\n",
    "        else: conv = Conv1d(ni + coord, nf, ks=kernel_size, bias=bias, stride=stride, padding=padding, **kwargs)\n",
    "        act = None if act is None else act(**act_kwargs)\n",
    "        if not separable: init_linear(conv, act, init=init, bias_std=bias_std)\n",
    "        # Weight/spectral norm wrap the conv itself instead of adding a separate layer\n",
    "        if   norm_type==NormType.Weight:   conv = weight_norm(conv)\n",
    "        elif norm_type==NormType.Spectral: conv = spectral_norm(conv)\n",
    "        layers += [conv]\n",
    "        # Built as [act, norm] and reversed when the norm has to come first\n",
    "        act_bn = []\n",
    "        if act is not None: act_bn.append(act)\n",
    "        if bn: act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))\n",
    "        if inn: act_bn.append(InstanceNorm(nf, norm_type=norm_type, ndim=ndim))\n",
    "        if bn_1st: act_bn.reverse()\n",
    "        # Dropout, when requested, sits right after the conv and before the act/norm layers\n",
    "        if dropout: layers += [nn.Dropout(dropout)]\n",
    "        layers += act_bn\n",
    "        if xtra: layers.append(xtra)\n",
    "        super().__init__(*layers)\n",
    "\n",
    "Conv = named_partial('Conv', ConvBlock, norm=None, act=None)\n",
    "ConvBN = named_partial('ConvBN', ConvBlock, norm='Batch', act=None)\n",
    "CoordConv = named_partial('CoordConv', ConvBlock, norm=None, act=None, coord=True)\n",
    "SepConv = named_partial('SepConv', ConvBlock, norm=None, act=None, separable=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class ResBlock1dPlus(Module):\n",
    "    \"Resnet block from `ni` to `nf` (both multiplied by `expansion`) with `stride`\"\n",
    "#     @delegates(ConvLayer.__init__)\n",
    "    def __init__(self, expansion, ni, nf, coord=False, stride=1, groups=1, reduction=None, nh1=None, nh2=None, dw=False, g2=1,\n",
    "                 sa=False, sym=False, norm='Batch', zero_norm=True, act_cls=defaults.activation, ks=3,\n",
    "                 pool=AvgPool, pool_first=True, **kwargs):\n",
    "        # Hidden widths default to the (unexpanded) output width\n",
    "        if nh2 is None: nh2 = nf\n",
    "        if nh1 is None: nh1 = nh2\n",
    "        nf,ni = nf*expansion,ni*expansion\n",
    "        # `k0` is used for the inner convs (with activation), `k1` for the last conv of the path (no activation)\n",
    "        k0 = dict(norm=norm, zero_norm=False, act=act_cls, **kwargs)\n",
    "        k1 = dict(norm=norm, zero_norm=zero_norm, act=None, **kwargs)\n",
    "        if expansion == 1:\n",
    "            # Two convs of size `ks`\n",
    "            convpath = [ConvBlock(ni,  nh2, ks, coord=coord, stride=stride, groups=ni if dw else groups, **k0),\n",
    "                        ConvBlock(nh2,  nf, ks, coord=coord, groups=g2, **k1)]\n",
    "        else:\n",
    "            # Bottleneck: a conv of size 1, a conv of size `ks`, then another conv of size 1\n",
    "            convpath = [ConvBlock(ni,  nh1, 1, coord=coord, **k0),\n",
    "                        ConvBlock(nh1, nh2, ks, coord=coord, stride=stride, groups=nh1 if dw else groups, **k0),\n",
    "                        ConvBlock(nh2,  nf, 1, coord=coord, groups=g2, **k1)]\n",
    "        if reduction: convpath.append(SEModule(nf, reduction=reduction, act_cls=act_cls))\n",
    "        if sa: convpath.append(SimpleSelfAttention(nf,ks=1,sym=sym))\n",
    "        self.convpath = nn.Sequential(*convpath)\n",
    "        # Shortcut branch: a conv of size 1 when the width changes, a pooling layer when `stride` is not 1\n",
    "        shortcut = []\n",
    "        if ni != nf: shortcut.append(ConvBlock(ni, nf, 1, coord=coord, act=None, **kwargs))\n",
    "        if stride != 1:\n",
    "            # The pooling layer goes before the conv when `pool_first`, after it otherwise\n",
    "            position = (1,0)[pool_first]\n",
    "            shortcut.insert(position, pool(stride, ndim=1, ceil_mode=True))\n",
    "        self.idpath = nn.Sequential(*shortcut)\n",
    "        self.act = defaults.activation(inplace=True) if act_cls is defaults.activation else act_cls()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.act(self.convpath(x) + self.idpath(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def SEModule1d(ni, reduction=16, act=nn.ReLU, act_kwargs={}):\n",
    "    \"Squeeze and excitation module for 1d\"\n",
    "    # Bottleneck width: `ni // reduction` rounded up to a multiple of 8\n",
    "    nf = math.ceil(ni//reduction/8)*8\n",
    "    assert nf != 0, 'nf cannot be 0'\n",
    "    squeeze = nn.AdaptiveAvgPool1d(1)\n",
    "    squeeze_conv = ConvBlock(ni, nf, ks=1, norm=None, act=act, act_kwargs=act_kwargs)\n",
    "    excite_conv = ConvBlock(nf, ni, ks=1, norm=None, act=nn.Sigmoid)\n",
    "    return SequentialEx(squeeze, squeeze_conv, excite_conv, ProdLayer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = torch.rand(8, 32, 12)\n",
    "# The module must return a tensor with the shape of its input\n",
    "se_out = SEModule1d(t.shape[1], reduction=16, act=nn.ReLU, act_kwargs={})(t)\n",
    "test_eq(se_out.shape, t.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def Norm(nf, ndim=1, norm='Batch', zero_norm=False, init=True, **kwargs):\n",
    "    \"Norm layer with `nf` features and `ndim` with auto init.\"\n",
    "    assert 1 <= ndim <= 3\n",
    "    # Look up the layer class in `torch.nn`, e.g. norm='Batch' and ndim=1 give `nn.BatchNorm1d`\n",
    "    layer_cls = getattr(nn, f\"{snake2camel(norm)}Norm{ndim}d\")\n",
    "    nl = layer_cls(nf, **kwargs)\n",
    "    if nl.affine and init:\n",
    "        # Small constant bias; the weight starts at 0 when `zero_norm`, at 1 otherwise\n",
    "        gain = 0. if zero_norm else 1.\n",
    "        nl.bias.data.fill_(1e-3)\n",
    "        nl.weight.data.fill_(gain)\n",
    "    return nl\n",
    "\n",
    "BN1d = partial(Norm, ndim=1, norm='Batch')\n",
    "IN1d = partial(Norm, ndim=1, norm='Instance')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs, ni, nf, sl, ks = 2, 3, 5, 4, 5\n",
    "\n",
    "t = torch.rand(bs, ni, sl)\n",
    "# The sequence length is kept with the default ('same') padding, with causal padding and with the coordinate channel\n",
    "for block_kwargs in ({}, dict(padding='causal'), dict(coord=True)):\n",
    "    test_eq(ConvBlock(ni, nf, ks, **block_kwargs)(t).shape, (bs, nf, sl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(BN1d(ni)(t).shape, (bs, ni, sl))\n",
    "# The norm weight is initialised to 1, or to 0 when `zero_norm=True`\n",
    "for zero_norm, expected in ((False, 1.), (True, 0.)):\n",
    "    test_eq(BN1d(ni, zero_norm=zero_norm).weight.data.mean().item(), expected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ConvBlock(\n",
       "  (0): AddCoords1d()\n",
       "  (1): Conv1d(4, 5, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
       "  (2): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (3): Swish()\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Layer 1 is the batch norm: its weight is all zeros only when `zero_norm=True`\n",
    "zero_bn = ConvBlock(ni, nf, ks, norm='batch', zero_norm=True)[1]\n",
    "test_eq(zero_bn.weight.data.unique().item(), 0)\n",
    "regular_bn = ConvBlock(ni, nf, ks, norm='batch', zero_norm=False)[1]\n",
    "test_ne(regular_bn.weight.data.unique().item(), 0)\n",
    "# Layer 0 is the conv: it has no bias when `bias=False`\n",
    "test_eq(ConvBlock(ni, nf, ks, bias=False)[0].bias, None)\n",
    "ConvBlock(ni, nf, ks, act=Swish, coord=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class LinLnDrop(nn.Sequential):\n",
    "    \"Module grouping `LayerNorm`, `Dropout` and `Linear` layers (the linear part comes first when `lin_first`)\"\n",
    "    def __init__(self, n_in, n_out, ln=True, p=0., act=None, lin_first=False):\n",
    "        norm_drop = []\n",
    "        # The norm is sized for whatever it receives: the linear output when `lin_first`, the raw input otherwise\n",
    "        if ln: norm_drop.append(nn.LayerNorm(n_out if lin_first else n_in))\n",
    "        if p != 0: norm_drop.append(nn.Dropout(p))\n",
    "        # The linear layer gets a bias only when no LayerNorm is used\n",
    "        linear_act = [nn.Linear(n_in, n_out, bias=not ln)]\n",
    "        if act is not None: linear_act.append(act)\n",
    "        ordered = linear_act + norm_drop if lin_first else norm_drop + linear_act\n",
    "        super().__init__(*ordered)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LinLnDrop(\n",
       "  (0): LayerNorm((2,), eps=1e-05, elementwise_affine=True)\n",
       "  (1): Dropout(p=0.5, inplace=False)\n",
       "  (2): Linear(in_features=2, out_features=3, bias=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# With the defaults the order is LayerNorm, Dropout, Linear\n",
    "LinLnDrop(n_in=2, n_out=3, p=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class LambdaPlus(Module):\n",
    "    \"Layer that calls `func` on its input, followed by the stored `args` and `kwargs`\"\n",
    "    def __init__(self, func, *args, **kwargs):\n",
    "        self.func = func\n",
    "        self.args = args\n",
    "        self.kwargs = kwargs\n",
    "    def forward(self, x):\n",
    "        return self.func(x, *self.args, **self.kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class Squeeze(Module):\n",
    "    \"Layer that squeezes dimension `dim` of its input\"\n",
    "    def __init__(self, dim=-1):\n",
    "        self.dim = dim\n",
    "    def forward(self, x):\n",
    "        return x.squeeze(dim=self.dim)\n",
    "    def __repr__(self):\n",
    "        return f'{self.__class__.__name__}(dim={self.dim})'\n",
    "\n",
    "\n",
    "class Unsqueeze(Module):\n",
    "    \"Layer that inserts a dimension of size 1 at position `dim`\"\n",
    "    def __init__(self, dim=-1):\n",
    "        self.dim = dim\n",
    "    def forward(self, x):\n",
    "        return x.unsqueeze(dim=self.dim)\n",
    "    def __repr__(self):\n",
    "        return f'{self.__class__.__name__}(dim={self.dim})'\n",
    "\n",
    "\n",
    "class Add(Module):\n",
    "    \"Layer that returns the sum of its two inputs\"\n",
    "    def forward(self, x, y):\n",
    "        return x.add(y)\n",
    "    def __repr__(self):\n",
    "        return f'{self.__class__.__name__}'\n",
    "\n",
    "\n",
    "class Concat(Module):\n",
    "    \"Layer that concatenates tensors along `dim`; they can be passed as one list/tuple or as separate arguments\"\n",
    "    def __init__(self, dim=1): self.dim = dim\n",
    "    def forward(self, *x):\n",
    "        # A single list/tuple is the original calling convention; separate tensors are now accepted as well\n",
    "        tensors = x[0] if len(x) == 1 and isinstance(x[0], (list, tuple)) else x\n",
    "        return torch.cat(tensors, dim=self.dim)\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim})'\n",
    "\n",
    "\n",
    "class Unfold(Module):\n",
    "    \"Layer that calls `Tensor.unfold` on its input with the stored `dim`, `size` and `step`\"\n",
    "    def __init__(self, dim, size, step=1):\n",
    "        self.dim = dim\n",
    "        self.size = size\n",
    "        self.step = step\n",
    "    def forward(self, x:Tensor) -> Tensor:\n",
    "        return x.unfold(dimension=self.dim, size=self.size, step=self.step)\n",
    "    def __repr__(self):\n",
    "        return f\"{self.__class__.__name__}(dim={self.dim}, size={self.size}, step={self.step})\"\n",
    "    \n",
    "    \n",
    "class Permute(Module):\n",
    "    \"Layer that reorders the dimensions of its input as given by `dims`\"\n",
    "    def __init__(self, *dims):\n",
    "        self.dims = dims\n",
    "    def forward(self, x:Tensor) -> Tensor:\n",
    "        return x.permute(self.dims)\n",
    "    def __repr__(self):\n",
    "        return f\"{self.__class__.__name__}(dims={', '.join([str(d) for d in self.dims])})\"\n",
    "    \n",
    "    \n",
    "class Transpose(Module):\n",
    "    \"Layer that swaps the two dimensions in `dims`, optionally returning a contiguous tensor\"\n",
    "    def __init__(self, *dims, contiguous=False):\n",
    "        self.dims = dims\n",
    "        self.contiguous = contiguous\n",
    "    def forward(self, x):\n",
    "        swapped = x.transpose(*self.dims)\n",
    "        return swapped.contiguous() if self.contiguous else swapped\n",
    "    def __repr__(self):\n",
    "        if not self.contiguous: return f\"{self.__class__.__name__}({', '.join([str(d) for d in self.dims])})\"\n",
    "        return f\"{self.__class__.__name__}(dims={', '.join([str(d) for d in self.dims])}).contiguous()\"\n",
    "    \n",
    "    \n",
    "class View(Module):\n",
    "    \"Layer that views its input as `(batch size, *shape)`: no `shape` flattens each sample, `-1` alone flattens everything\"\n",
    "    def __init__(self, *shape):\n",
    "        self.shape = shape\n",
    "    def forward(self, x):\n",
    "        if not self.shape: out = x.view(x.shape[0], -1)\n",
    "        elif self.shape == (-1,): out = x.view(-1)\n",
    "        else: out = x.view(x.shape[0], *self.shape)\n",
    "        return out.contiguous()\n",
    "    def __repr__(self):\n",
    "        return f\"{self.__class__.__name__}({', '.join(['bs'] + [str(s) for s in self.shape])})\"\n",
    "    \n",
    "    \n",
    "class Reshape(Module):\n",
    "    def __init__(self, *shape): self.shape = shape\n",
    "    def forward(self, x):\n",
    "        return x.reshape(x.shape[0], -1) if not self.shape else x.reshape(-1) if self.shape == (-1,) else x.reshape(x.shape[0], *self.shape)\n",
    "    def __repr__(self): return f\"{self.__class__.__name__}({', '.join(['bs'] + [str(s) for s in self.shape])})\"\n",
    "    \n",
    "    \n",
    "class Max(Module):\n",
    "    def __init__(self, dim=None, keepdim=False): self.dim, self.keepdim = dim, keepdim\n",
    "    def forward(self, x): return x.max(self.dim, keepdim=self.keepdim)[0]\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim}, keepdim={self.keepdim})'\n",
    "\n",
    "    \n",
    "class LastStep(Module):\n",
    "    def forward(self, x): return x[..., -1]\n",
    "    def __repr__(self): return f'{self.__class__.__name__}()'\n",
    "    \n",
    "    \n",
    "class SoftMax(Module):\n",
    "    \"SoftMax layer\"\n",
    "    def __init__(self, dim=-1):\n",
    "        self.dim = dim\n",
    "    def forward(self, x):\n",
    "        return F.softmax(x, dim=self.dim)\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim})' \n",
    "    \n",
    "\n",
    "class Clamp(Module):\n",
    "    def __init__(self, min=None, max=None):\n",
    "        self.min, self.max = min, max\n",
    "    def forward(self, x):\n",
    "        return x.clamp(min=self.min, max=self.max)\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(min={self.min}, max={self.max})' \n",
    "    \n",
    "    \n",
    "class Clip(Module):\n",
    "    def __init__(self, min=None, max=None):\n",
    "        self.min, self.max = min, max\n",
    "        \n",
    "    def forward(self, x):\n",
    "        if self.min is not None:\n",
    "            x = torch.maximum(x, self.min)\n",
    "        if self.max is not None:\n",
    "            x = torch.minimum(x, self.max)\n",
    "        return x\n",
    "    def __repr__(self): return f'{self.__class__.__name__}()' \n",
    "    \n",
    "    \n",
    "class ReZero(Module):\n",
    "    \"Residual wrapper computing `x + alpha * module(x)`, where `alpha` is a one-element parameter initialised at zero.\"\n",
    "    def __init__(self, module):\n",
    "        self.module = module\n",
    "        self.alpha = nn.Parameter(torch.zeros(1))\n",
    "    def forward(self, x):\n",
    "        branch = self.module(x)\n",
    "        return x + self.alpha * branch\n",
    "    \n",
    "    \n",
    "# An empty `nn.Sequential` returns its input unchanged\n",
    "Noop = nn.Sequential()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Transpose(1, 2),\n",
       " Permute(dims=0, 2, 1),\n",
       " View(bs, -1, 2, 10),\n",
       " Transpose(dims=1, 2).contiguous(),\n",
       " Reshape(bs, -1, 2, 10),\n",
       " Sequential())"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 2\n",
    "nf = 5\n",
    "sl = 4\n",
    "\n",
    "# Input of shape (bs, nf, sl) = (2, 5, 4), i.e. 40 elements in total\n",
    "t = torch.rand(bs, nf, sl)\n",
    "test_eq(Permute(0,2,1)(t).shape, (bs, sl, nf))\n",
    "test_eq(Max(1)(t).shape, (bs, sl))\n",
    "test_eq(Transpose(1,2)(t).shape, (bs, sl, nf))\n",
    "test_eq(Transpose(1,2, contiguous=True)(t).shape, (bs, sl, nf))\n",
    "# View/Reshape keep the first dimension and apply the given shape to the rest: 20 = 1 * 2 * 10\n",
    "test_eq(View(-1, 2, 10)(t).shape, (bs, 1, 2, 10))\n",
    "test_eq(Reshape(-1, 2, 10)(t).shape, (bs, 1, 2, 10))\n",
    "# No shape flattens everything but the first dimension; a lone -1 flattens the first dimension too\n",
    "test_eq(Reshape()(t).shape, (2, 20))\n",
    "test_eq(Reshape(-1)(t).shape, (40,))\n",
    "# The last expression displays the reprs of the layers\n",
    "Transpose(1,2), Permute(0,2,1), View(-1, 2, 10), Transpose(1,2, contiguous=True), Reshape(-1, 2, 10), Noop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class DropPath(nn.Module):\n",
    "    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n",
    "\n",
    "    It's similar to Dropout but it drops individual connections instead of nodes.\n",
    "    Original code in https://github.com/rwightman/pytorch-image-models (timm library)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, p=None):\n",
    "        super().__init__()\n",
    "        self.p = p\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.p == 0. or not self.training: return x\n",
    "        keep_prob = 1 - self.p\n",
    "        shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n",
    "        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)\n",
    "        random_tensor.floor_()\n",
    "        output = x.div(keep_prob) * random_tensor\n",
    "#         output = x.div(random_tensor.mean()) * random_tensor # divide by the actual mean to mantain the input mean?\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = torch.ones(100,2,3)\n",
    "# With p=0 the input is returned unchanged\n",
    "test_eq(DropPath(0.)(t), t)\n",
    "# A new nn.Module starts in training mode, so with p=0.5 the kept samples are divided by keep_prob=0.5, giving 2\n",
    "assert DropPath(0.5)(t).max() >= 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class Sharpen(Module):\n",
    "    \"This is used to increase confidence in predictions - MixMatch paper\"\n",
    "    def __init__(self, T=.5): self.T = T\n",
    "    def forward(self, x):\n",
    "        x = x**(1. / self.T)\n",
    "        return x / x.sum(dim=1, keepdims=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABJs0lEQVR4nO3deVxU1fsH8M+AgCugoSCK+547JpJbKkmu7ZGamusvU1MpTTTXSizL5ZuaabmUuZdWLpRiahblnkuGuW+AKwyigjDn98cj4CgoAzNzZ/m8X695nXPv3OWZKzCP5557jk4ppUBERESkERetAyAiIiLnxmSEiIiINMVkhIiIiDTFZISIiIg0xWSEiIiINMVkhIiIiDTFZISIiIg0xWSEiIiINFVI6wDywmAw4OLFiyhRogR0Op3W4RAREVEeKKWQnJwMf39/uLjk3v5hF8nIxYsXERAQoHUYRERElA/nzp1D+fLlc33fLpKREiVKAJAP4+npqXE0RERElBd6vR4BAQFZ3+O5sYtkJPPWjKenJ5MRIiIiO/OoLhbswEpERESaYjJCREREmmIyQkRERJpiMkJERESaYjJCREREmmIyQkRERJpiMkJERESaYjJCREREmmIyQkRERJoyORnZsWMHunTpAn9/f+h0Oqxbt+6R+2zbtg2NGzeGh4cHqlWrhsWLF+cjVCIiInJEJicjKSkpaNCgAebMmZOn7U+dOoVOnTqhTZs2OHDgAIYPH47+/fvj559/NjlYIiIicjwmz03ToUMHdOjQIc/bz5s3D5UrV8ann34KAKhduzZ27tyJGTNmIDQ01NTTExERkYOx+ER5MTExCAkJMVoXGhqK4cOH57pPamoqUlNTs5b1er2lwiMiIjKbO3eA5GQgKUleN28CqalAWpqUt2/LKz0dUEpeBkPe66Zs+9B6WjpUkh4qKQmGRD1Ukh7hC+uhUkNvTa6bxZOR+Ph4+Pr6Gq3z9fWFXq/HrVu3UKRIkQf2iYyMxKRJkywdGhERUZ4pBVy9Chw7BsTGSpn5iosDbtyQhMM+FAJQ6u5LdN9z2HGTkfyIiIhAeHh41rJer0dAQICGERERkbNISQGOH8856bh+PW/HKFIE8PICihUD3N0BD4/sV+HCQKFCgE4HuLhIme962m3orl6Fy/Ur0F29Ct3VK3BJToQu/Q5cYIAOCjqonOvubnDx8oTO2ws6by/4V61i2Qv7EBZPRvz8/JCQkGC0LiEhAZ6enjm2igCAh4cHPDw8LB0aERE5uWvXgKgoYMMG4MQJ4Px54MKFh+9ToQJQo4bxKyAAKFEi++XmZoFgMzKkaeb4ceDHH4GNG4FDh3LetkgRoFIlwN8f8POTV+nSQNmyQK1aQNWqQKlSktXYAIsnI8HBwdi4caPRus2bNyM4ONjSpyYiInpAQgKweDHw5ZfyvZ4TH58HE44aNeQ7vGhRKwarFPD77xLsunXSEeV+1aoBDRsCjRoBDRpIoBUrSpOMnTA5Gblx4waO3/Ovd+rUKRw4cAClSpVChQoVEBERgQsXLuDrr78GALzxxhuYPXs2Ro0ahb59+2Lr1q1YtWoVNmzYYL5PQURE9BCJicCqVcD33wNbtkgjQ6Y6dYBnnwWaNpWGg+rVpdFAU7duAePGAcuXAxcvGr9XpgzQti3QsiXQtStQvrw2MZqRycnInj170KZNm6zlzL4dvXv3xuLFixEXF4ezZ89mvV+5cmVs2LABI0aMwKxZs1C+fHl8+eWXfKyXiIgsLjUVmDULmDrVuL9HUBAQFgZ06yZ3MGyGXi8Z07hxcs8IkI4m7doB4eFAq1YWugekLZ1SSmkdxKPo9Xp4eXkhKSkJnp6eWodDREQ27PZtYPNm4JtvgJ9/lu93QO5mvPaaJCA1amgbo5EjR4AVK4CffpJ6erqs9/ICxo4Fhg6VXq92KK/f3zb5NA0REZGpDAbg22+BYcOMW0HKlgUGDJDGhkK28q2X
lgasXAlMnw4cOGD8XvXqQJ8+wODBgJP8B9xW/lmIiIjyxWAAVq+W2zExMbKuVCmgVy+5FdO0qTwGazNWrQJGjMjuC+LmBjzzjATbsqU8ruNkmIwQEZHdOn9eulKsXi3Lrq6yPGmSPN1qM9LT5TbM3LnSgxaQJpvBg4FBg2ygx6y2mIwQEZFdMRiAefOkT8iff8o6nQ54/XVg1CgZRsNmXL8OzJ4NfPFF9gAmOh3wf/8HzJwpnVOJyQgREdmXiROB99/PXm7cGPjwQ7nTYVOSk+Xpl8OHZbl0aaB/f2DgQBmQjLIwGSEiIrtx4gRwdxJ4dOgAzJ9vg8NsKAV89RXwwQfAmTPy7PAnnwAvvcSWkFwwGSEiIrtw6hTQpo3MhFu+PLB2rY19tysFLF0qTTcnT8q60qUl0GbNNA3N1tlS/2IiIqIcnTwJBAYC587JHY5Nm2wsEQHk9kuvXhJssWLSk/bkSSYiecCWESIismmnT8sApNevy9DtGzbYWJeLq1eB3r0lMEAGKhs9GiheXNu47AhbRoiIyGZt3SoJyOnTwGOPAWvW2FgiEhMjnVQzE5HRo6WvCBMRk7BlhIiIbNKlS0DHjjK/TLlyQHQ0ULOm1lHdlZEBjB8PREZKX5GSJWUM+sBArSOzS2wZISIimzR9uiQi7u7A/v02lIgAMsjJlCmSiLRrJwEyEck3towQEZHNOXwYmDFD6lOmyEMpNuPSJRk1FQDq1ZPetA44k641sWWEiIhsyuHDQPPmMpdckyYyjYtNWb9epgb29AQ2bmQiYgZMRoiIyGZERQFt2wJ6vfQTWbfOxia5U0pm5AOAIUNscMQ1+2RL/8REROTEFi+WUVUvX5a7H3/9JQmJTfnoI+DgQXlaZuhQraNxGExGiIhIcydPAn36SP3FF20wEbl5Exg2DIiIkOVu3WSYdzILdmAlIiLNtW0rpY+PTOtSpIi28RhZtkxm2b1xQ5affRaYO1fbmBwMW0aIiEhT69bJfHIAMHky4OWlaTjZlAKmTgV69MhOREaMAL7/HijE/8ubE68mERFpasECKUuWlAYIm3Dnjsw1s3ixLPv4yGM+vr6ahuWo2DJCRESaOXJEnqABZGR1m3hy5uxZIDhYEhEXF7klc+kSExELYssIERFpIj1dGh8MBqBLFxsZYfXoUZkMB5BbMWvWSB8RsigmI0REpImBA4E//pCnZGfP1joaALt3A6Gh2cvR0TIJHlkckxEiIrIqpeQp2UWLZHnpUqBCBW1jAgC8+SZw/bq0iBw8CNSurXVEToPJCBERWdXKlcBnnwE6HfDJJzZwF8RgAP73P2DPHlneuZOJiJXZQlchIiJyEkoB77wj9XffBcLDtY0HgDzOkzkBTufOQFCQtvE4ISYjRERkNStWABcuSN0mHuNduTJ7Bt5OnYDVq7WNx0nxNg0REVlFSopM7QIA1asDlSppGg5w/DjQvbs017RpA3z3HeDhoXFQzonJCBERWVx6utye+ftvoEQJYO1ajQO6fRto3Vr6i9SsCWzeDLi6ahyU8+JtGiIisiilgLFjgXnzZHn6dODxx7WNCYsWARcvSv3bb5mIaIzJCBERWdTo0cDHH0u9f3+gXz9t48GOHcBbb0n9o4+AwEBt4yEmI0REZDlbtmQnIp9+Kg+u6HQaBnTiBNCxo9w36tgx+yka0hSTESIisogdO4Cnn5Z616428BivUkDPntKT1s9PMiM3N42DIoDJCBERWUBiItCnj9TLlAG++krTcMTy5TIbHyCz8/n7axsPZWEyQkREZjdyJHDypIysvmUL4OOjcUCHDwMDBki9Xz+gQQNt4yEjTEaIiMisjh8HFi+W+vvvA/XqaRqO3J4ZPBi4eRNo3NhGZuWjezEZISIisxo7VvqHVqgADB+udTSQx3h37JD63LlA4cLaxkMPYDJCRERm8/ffwKpVUl+61Aa+99PSgPfek/qIEZx3xkYxGSEiIrPZtUvK4GCgZUttYwEgLSJxcfI8
8cSJWkdDuWAyQkREZpGWBrz5ptTbtdM2FgBARgbw8stS79YN8PTUNh7KFZMRIiIyi5kzpa+Ih0d2UqKZuDjprJqYKMuTJmkaDj0ckxEiIiqwjIzskVaHDwfKltU0HOCll4CDB6U+fTpQrZq28dBDcdZeIiIqsN9+A65elfr772sbC9atA/74Q+rbtwOtWmkaDj0aW0aIiKhAMjKAN96Qet++Go+wvnw58OKLUq9alYmInWAyQkREBTJxIhAbC5QsKZPhaSIxEYiIALp3BwwGoHp1YOtWjYIhU/E2DRER5ds//wBTp0r9gw8Ab28Ngjh+HAgNlfHnASAsDPj6a8DdXYNgKD/YMkJERPk2apQ8QRMUBAwapEEAt24BTzwhiUiJEsCMGcCyZUxE7AxbRoiIKN9iY6V86y0ZV8zqnn8++/Hd3buBmjU1CIIKii0jRESUL5s3yx0SQIb0sLp//wV+/lnqy5czEbFjTEaIiChfPvxQysqVgRo1rHzya9ekbwggM/K9+qqVAyBzYjJCREQmi4qSITwAYNo0wMXa3yazZsmgZl5eQHS0lU9O5sZkhIiITPbZZ1K2bAm88IIGAaxZIyVHV3UITEaIiMgkycnAli1Snz1bg46rS5fKM8VubhplQmRuTEaIiMgk69bJDL3VqgH16mkQwMyZUrZpo9HAJmRuTEaIiMgkixdL2bOnBq0i334L7N0r9enTrXxyshQmI0RElGdxcdmjrD//vJVPPnEi8NprUm/UCHj8cSsHQJbCZISIiPJs4kQp69Wz8i2a06eBSZOk3q0bsH69FU9OlpavZGTOnDmoVKkSChcujKCgIOzateuh28+cORM1a9ZEkSJFEBAQgBEjRuD27dv5CpiIiLRx44aMtA4AI0ZY+eSTJ0tZo4Z0YPX3t3IAZEkmJyMrV65EeHg4JkyYgH379qFBgwYIDQ3FpUuXctx+2bJlGD16NCZMmICjR4/iq6++wsqVKzFmzJgCB09ERNazZo0kJNWqAa+/bsUTp6cDixZJfeBADQY1IUsz+V90+vTpGDBgAPr06YM6depg3rx5KFq0KBYuXJjj9n/88QeaN2+O7t27o1KlSmjfvj26dev2yNYUIiKyLT/8IGWvXlbuuLp5c3a9WzcrnpisxaRkJC0tDXv37kVISEj2AVxcEBISgpiYmBz3efLJJ7F3796s5OPkyZPYuHEjOnbsmOt5UlNTodfrjV5ERKSdtDTg11+lHhpq5ZPv3Cllq1a8PeOgTJq198qVK8jIyICvr6/Rel9fX/z777857tO9e3dcuXIFLVq0gFIK6enpeOONNx56myYyMhKTMjsqERGR5saOBZKSgDJlgMBAK5749m0ZWQ3IfpKGHI7Fb7xt27YNU6ZMwdy5c7Fv3z58//332LBhA95///1c94mIiEBSUlLW69y5c5YOk4iIHiKz4+ro0YCrqxVPvHQpoNcDxYoBD2lRJ/tmUsuIj48PXF1dkZCQYLQ+ISEBfn5+Oe4zbtw49OzZE/379wcA1KtXDykpKRg4cCDGjh0Llxw6Inl4eMDDw8OU0IiIyEKuXQMuXpR6375WPvmePVIOGQKUK2flk5O1mNQy4u7ujsDAQETfM0OiwWBAdHQ0goODc9zn5s2bDyQcrnfTaqWUqfESEZGVrVsnZc2aMkmu1RgMwMaNUrfqvSGyNpNaRgAgPDwcvXv3RpMmTdC0aVPMnDkTKSkp6NOnDwCgV69eKFeuHCIjIwEAXbp0wfTp09GoUSMEBQXh+PHjGDduHLp06ZKVlBARkW06fz57TBGrd9k4fhzIvE3PWzQOzeRkJCwsDJcvX8b48eMRHx+Phg0bIioqKqtT69mzZ41aQt577z3odDq89957uHDhAkqXLo0uXbrgww8/NN+nICIii1i9WrpsVK0KjBplxRMrBbz5ptSDgqTPCDksnbKDeyV6vR5eXl5ISkqCp6en1uEQETmNp54Ctm8Hpk4F3n3Xiif+6Sega1epb9jA
lhE7ldfvbw5jR0REObpzB8gcn7JLFyue+L//gLu3/vH660xEnACTESIiytGyZcCtWzK2SK1aVjzxJ58AV68Cnp7ARx9Z8cSkFSYjRET0gL17s+efefFFK04Ho5ScHAA+/VQyIXJ4TEaIiOgB8+dn163WOJGWBgwbJsmIhwfQoYOVTkxaYzJCRERGbt0CliyR+qJFQIkSVjrxV18Bn30m9WnTOMiZE2EyQkRERpYuBVJTgVKlgN69rXjizDHnn38eGDrUiicmrTEZISIiIwsXSvnuu4BOZ6WTJiZmz847ebKVTkq2gskIERFlOXoU+PNPqXfubMUTf/ONlOXLAzVqWPHEZAuYjBARUZatW7PrtWtb6aR6vfQRAYDhwwF3dyudmGwFkxEiIspy/ryUQ4ZY8RZNv34yB03ZskCvXlY6KdkSJiNERJQl8xZNQICVTrhtG7BmjdRXrwZKl7bSicmWMBkhIiIAwD//SG4AAMHBVjjhlSvASy9JPTQUaN7cCiclW8RkhIiIAADR0VIGBwMtW1rhhMuWybDvxYsDX39thROSrWIyQkREAIAzZ6S0SqvImTPAxIlSf/ttDvvu5JiMEBERAOD0aSkrVLDwiZQCevQArl+XGfiGD7fwCcnWMRkhIiIkJwObN0u9QQMLn2zWLOD33wE3N+C77wBvbwufkGwdkxEiIsK338pwHz4+Fu4v8ttvwMiRUu/TB6hTx4InI3vBZISIyMkpBXz+udSHDgVcXS10opMngVdfBdLTgbZtgblzLXQisjdMRoiInNzEicDBg/JQi8Xmp9u9G3jySeDiRbkt8803Fsx6yN4wGSEicmLp6cDUqVIfMwYoWdICJ4mLk0QkIQGoXh3Ytw/w97fAicheFdI6ACIi0s7Jk0BaGlCokMzSaxFffilZj04HxMQAjz1moRORvWLLCBGRE8sc6Kx5c8DFEt8IiYnA+PFSf/ddJiKUIyYjRERObN06KZ9+2kInuLcTytixFjoJ2TsmI0RETiouDvjlF6l36mSBE6xaJUO+A8C0adJDligHTEaIiJzUtGlSNm4MNGxo5oPr9TKOiMEAdO0qQ74T5YLJCBGRk9q1S8rQUAsc/IUXgJs3ZZTVZcuk8ypRLpiMEBE5ofR0YP9+qffsacYDGwxywMyesV9/DRQrZsYTkCNiMkJE5IROnpSGiyJFgJo1zXjgQYOApUulPnCgjLhK9AhMRoiInNCvv0pZs6YZH+ndvRuYP1/qkZHAF1+Y6cDk6JiMEBE5oTVrpDTbpHgZGdJRFZBRVi02gho5IiYjRERORilgzx6p9+5tpoN+/TUQHy/1xYvZYZVMwmSEiMjJrF8vA6MWLgzUq2eGA6anA598IvVGjWQeGiITMBkhInIyy5dL2asX4O5uhgNu3Qr8849kN998Y4YDkrNhMkJE5GQyR1197jkzHXDmTCm7dAEef9xMByVnwmSEiMiJJCcDV69KvXnzAh4sI0NGVt20SZafeaaAByRnVUjrAIiIyHp++03Kxx4DPD0LcKBbt4CXXwY2bJDlPn2A7t0LHB85J7aMEBE5iYsXs0db7dGjgAd7773sRGTWLGDhQukzQpQPTEaIiJzEmjXAtWtSHz++AAc6fx6YPl3qn3wCDB1a4NjIuTEZISJyEtu2STl8uNymybePPpLS1xd46y2OKUIFxmSEiMgJXLuWfVelwBPjHTok5ZQpMisvUQExGSEicgL//QekpQFeXjIuWYGcPCll7doFjosIYDJCROQU4uKkrFWrgHdVbt2SPiMAULlygeMiApiMEBE5hYsXpSxbtoAHmjZNJrfx95c+I0RmwGSEiMgJnDkjpb9/AQ4SGwt8+KHUP/2UHVfJbJiMEBE5gb17paxSJZ8HMBiAgQOl40mHDkBYmNliI2IyQkTk4FJSgL/+knqzZvk8yC+/ADt2AEWLAnPnslWEzIrJCBGRgxs2DLhxAyhdGggMzOdBNm+W8sUXgUqVzBUaEQAmI0REDi0x
EfjqK6l/9VU+R2y/cyf7IJyVlyyAyQgRkQNbu1ZKPz+gc+d8HuSXX4CkJLk107u32WIjysRkhIjIgcXGStm+fQG6eSxZIuVbb0lWQ2RmTEaIiBzY/v1S5nvU1dOngdWrpV7gceSJcsZkhIjIgR08KGVQUD4PkNkqUr8+0LixWWIiuh+TESIiB3XuHBAfL/WaNfN5kMxWkX79+DgvWQyTESIiBzVzppQtWgClSuXjAPv3A0eOSP3ZZ80VFtEDmIwQETmoX36RctCgfOycmgp06yb1kBCgQgWzxUV0PyYjREQO6NAh4PBhubPy1FP5OEBkpDyK4+UFLF7MWzRkUflKRubMmYNKlSqhcOHCCAoKwq5dux66fWJiIgYPHoyyZcvCw8MDNWrUwMaNG/MVMBERPdrKlVJ26JCPyfH++guYMkXq8+cD5cqZNTai+xUydYeVK1ciPDwc8+bNQ1BQEGbOnInQ0FDExsaiTJkyD2yflpaGp59+GmXKlMGaNWtQrlw5nDlzBt7e3uaIn4iIcpD5FE2HDvnYsWNHGXW1a1fg5ZfNHhvR/XRKKWXKDkFBQXjiiScwe/ZsAIDBYEBAQACGDh2K0aNHP7D9vHnzMG3aNPz7779wc3PLV5B6vR5eXl5ISkqCp6dnvo5BROQslAJc7rZ779gBtGxpws4vvCDDttasKS0kXl4WiZGcQ16/v026TZOWloa9e/ciJCQk+wAuLggJCUFMTEyO+/z4448IDg7G4MGD4evri7p162LKlCnIyMjI9TypqanQ6/VGLyIiypvffsuumzQ0yJUrwIYNUl+8mIkIWY1JyciVK1eQkZEBX19fo/W+vr6Iz3yY/T4nT57EmjVrkJGRgY0bN2LcuHH49NNP8cEHH+R6nsjISHh5eWW9AgICTAmTiMip/fqrlJUqAcWKmbDjlClAWpr0Ecn3KGlEprP40zQGgwFlypTB/PnzERgYiLCwMIwdOxbz5s3LdZ+IiAgkJSVlvc6dO2fpMImIHMaWLVKOGGHCTitWADNmSH3GDD49Q1ZlUgdWHx8fuLq6IiEhwWh9QkIC/HKZPKls2bJwc3ODq6tr1rratWsjPj4eaWlpcHd3f2AfDw8PeHh4mBIaERFBGjZ27pR606Z53Gn/fqB/f6n3789Oq2R1JrWMuLu7IzAwENHR0VnrDAYDoqOjERwcnOM+zZs3x/Hjx2EwGLLWHTt2DGXLls0xESEiovzLfKTX0zOPd1oMBuCll4CUFKB1a2DuXIvGR5QTk2/ThIeHY8GCBViyZAmOHj2KQYMGISUlBX369AEA9OrVCxEREVnbDxo0CNeuXcOwYcNw7NgxbNiwAVOmTMHgwYPN9ymIiAgA8PnnUvbvn8c7Ldu2ASdPSv3bb4F8PvVIVBAmjzMSFhaGy5cvY/z48YiPj0fDhg0RFRWV1an17NmzcHHJznECAgLw888/Y8SIEahfvz7KlSuHYcOG4d133zXfpyAiIhw5AmQ+2NirVx53WrxYyi5dOLgZacbkcUa0wHFGiIgebexYeSCmc2fgp5/ysMP33wMvviiDkmzbZuKAJESPZpFxRoiIyHbFxkrZrl0eNk5Ozu602qkTExHSFJMRIiIHoBTw999Sr1kzDzvMng1cvw4UKiR9RYg0xGSEiMgBbN4MHD8ug5w9+WQedshMQCZMAEqUsGhsRI/CZISIyAFkDnTWrVseRnFfsEB6u7q4AK+9ZvHYiB6FyQgRkQPYtk3KR7aK7NgBDBwo9f79Zcx4Io0xGSEisnPx8cDu3TKuSIcOj9h4xQopdTpg+nSLx0aUF0xGiIjs3P79UtauDeQyM0e27dulXLbMxFn0iCyHyQgRkZ3L7Iv6+OOP2PDcOeCff6QeEmLRmIhMwWSEiMjOZXZebd/+IRulpQHdu0u9VSvAx8ficRHlFZMRIiI7lpEBXL4s9c6dH7Jh167Z0/lmdmAlshFMRoiI7Fhioky8CwClSuWy0d69wM8/S/3jj7NbSIhs
BJMRIiI7duWKlJ6egLt7LhuNHi2lnx/wzjt5nM6XyHqYjBAR2bFz56S8O3H6gzIysgch+fRTJiJkk5iMEBHZsbVrpWzcOJcNDh0C0tOl2eSVV6wWF5EpmIwQEdmpP/4APv9c6n375rBBRgYwZozUu3aVSfGIbBCTESIiO7Vxo8zW+/TT8nrAwoXApk1SHzLEqrERmYLJCBGRnTp8WMrOnXPpCvLNN1KOHAm0bm21uIhMxWSEiMhOHTsmZa1aOby5fj3w229Sf/NNq8VElB9MRoiI7ND168DRo1KvXv2+N69eBXr3lvoLL3BmXrJ5TEaIiOzQ2LFSVqmSQ64xeTJw7ZrUBw+2ZlhE+cJkhIjIzsTHZz9FM3Lkff1F4uOB+fOlvnIl0Lat1eMjMhWTESIiO/Pll1JWrQq88cY9bxgMkp3cvg0EBgIvv6xJfESm4kPnRER2Zv16KY0ekElLk9l4//pLlocO5WirZDfYMkJEZEf+/VfyDVdXYMqUe96YOjU7EZkyBejVS5P4iPKDLSNERHZkyRIpO3S4Zz6a+Hhg0iSpT5oERERoEhtRfrFlhIjITmRkAIsWSf311++uTE8HBg6U/iK1agHvvadVeET5xmSEiMhOHDkCJCQAxYvLVDPIyJDJ7376STZ4803AhX/Wyf7wNg0RkZ34/XcpmzUD3NwArPspe9reBQtymS2PyPYxGSEisgMZGdn9RVq2hNyeyezBOmgQ0L+/ZrERFRTb84iI7MDixdkPy3RpmSiDme3eDRQunD0cK5GdYjJCRGTjlMoeVPWNN4BGv3wkk+C5uMgoq+XKaRsgUQHxNg0RkY3bsQPYtQvw8AAmjUkFgu7er/nf/+72ZCWyb2wZISKycZMnS9m9O1Bm/BtAXBxQtiz7iZDDYDJCRGTDLl0Ctm6V+siKq6TzCCD3azw8NIuLyJyYjBAR2bCdO6WsUeEWak8Mk4VWrYBx47QLisjMmIwQEdmwuDgp6ybeHWSkYUNgwwZOgkcOhckIEZENO39eSj99rFSWLJEhWIkcCJMRIiIbdvSolDVwDOjcGahfX9uAiCyAj/YSEdmo06eBn6MMAFzQyPUgsHCV1iERWQRbRoiIbNSmTcDtVBc8gV1o2dYdKF1a65CILILJCBGRjTq4WvqJPIHd0I0do3E0RJbDZISIyBb9+CN+/tUNABDc2Qdo3VrjgIgsh8kIEZENOj/pK5xCFQBAi0+e0zYYIgtjMkJEZGs++ADL9tUEAJQocgcVa3CkVXJsfJqGiMiWnD4Nw/iJmImzAIAXw9w4vhk5PLaMEBHZkn79cFWVRBz8AXDUd3IOTEaIiGzF/v3A1q2Ihx8AwMsLqFJF45iIrIDJCBGRrZg/HwDwXcAIAEDjxloGQ2Q9TEaIiGxBQgKwcCEAYFnqiwCA/v21DIjIepiMEBHZgjFjgLQ03KjbDP9d8gIAtGuncUxEVsJkhIhIa4mJwLJlAIBv2i0GAPj5cfR3ch5MRoiItJSQALRoAdy+DVSqhG1xNQAAffsCLvwLTU6CP+pERFpRCujUCThyRBYXLcaOHTKoyNNPaxkYkXUxGSEi0sqJE8DevVKfOxc/6VsjPh5wdweCgrQNjciaOAIrEZFWvvlGyurVgUGD8M3LstizJ1CkiHZhEVkbW0aIiLSQlgZ8+aXUIyKweTOwZo0s9u6tXVhEWshXMjJnzhxUqlQJhQsXRlBQEHbt2pWn/VasWAGdTofnnnsuP6clInIcR44AFy8CRYsCr7yCpUtldcOGQMuWmkZGZHUmJyMrV65EeHg4JkyYgH379qFBgwYIDQ3FpUuXHrrf6dOn8c4776Alf8uIiIAtW6Rs3hwoVgxxcbL48svahUSkFZOTkenTp2PAgAHo06cP6tSpg3nz5qFo0aJYeHfkwJxkZGSgR48emDRpEqpwogUicnYZGdn9Re62FGf+fy4wUJuQiLRkUjKSlpaGvXv3IiQkJPsALi4ICQlBTExMrvtNnjwZZcqU
Qb9+/fJ0ntTUVOj1eqMXEZHDmD0bOHQI8PQEXnkFSgHnz8tbZcpoGxqRFkxKRq5cuYKMjAz4+voarff19UV8fHyO++zcuRNfffUVFixYkOfzREZGwsvLK+sVEBBgSphERLbLYAC++ELqAwcCPj74/nvg6lWgeHGgWjVtwyPSgkWfpklOTkbPnj2xYMEC+Pj45Hm/iIgIJCUlZb3OnTtnwSiJiKxo1Srg6FGpj5DZeaOiZPHll4ESJTSKi0hDJo0z4uPjA1dXVyQkJBitT0hIgJ+f3wPbnzhxAqdPn0aXLl2y1hkMBjlxoUKIjY1F1apVH9jPw8MDHh4epoRGRGT7bt0CunWT+pgxgL8/AODsWVnVurVGcRFpzKSWEXd3dwQGBiI6OjprncFgQHR0NIKDgx/YvlatWjh06BAOHDiQ9eratSvatGmDAwcO8PYLETmXOXOy6337ZlWPHZOyYkUrx0NkI0wegTU8PBy9e/dGkyZN0LRpU8ycORMpKSno06cPAKBXr14oV64cIiMjUbhwYdStW9dof29vbwB4YD0RkUNLTgY++kjqQ4cCd1uFz50DTp+WSfEaN9YuPCItmZyMhIWF4fLlyxg/fjzi4+PRsGFDREVFZXVqPXv2LFw41SQRUTaDAXjySeDKFcDNDZg0KeutJUukfOIJebiGyBnplFJK6yAeRa/Xw8vLC0lJSfDkbysR2ZsffwSefVbqy5cDr76a9VaLFsDvv8sDNgMHahQfkYXk9fubTRhERJaUmAi8/bbUR40ySkRSUoDMIZratrV+aES2gskIEZGl3L4NPP88cPy4PDkzapTR2/HxcgenaFGOL0LOjckIEZGldOkCbNsGeHgA330HPPaY0duZoyTcN44kkdNhMkJEZCm7d0sZHg40a/bA2wcOSMkh4MnZMRkhIrKEpUuBpCSpjx2b4yZffill06ZWionIRjEZISIyt4wMYNw4qT//PFCs2AObnD8P/P231IcPt15oRLaIyQgRkbn984+MZObmBnz9dY6bLFkinVdbtwaqVLFueES2hskIEZG5bd8uZePGMhVvDlatkvLu4NVETo3JCBGROWVkyAhmAPDMMzluohTw339Sb97cSnER2TAmI0RE5rRkCXD4sEw206lTjptcuSIT+AIA5wslYjJCRGQ+t24BERFSHzJEJpzJwcmTUvr7yxAkRM6OyQgRkblERQGXLgGFCgETJ+a62bFjUtaoYZ2wiGwdkxEiInOZOVPKHj2AkiVz3SyzvwiTESLBZISIyBxOnwZ27JD6a6/lull6OrBmjdTr1rV8WET2gMkIEVFBJScDXbtKvWFD4Kmnct10yxbg6FHA3R145RWrREdk85iMEBEV1CefAIcOAT4+wLp10mckF999J+Vrr3GCPKJMTEaIiApCKeDzz6U+ejRQseJDN88c7OyFFywcF5EdYTJCRFQQf/0FXL4s9Zdeeuimv/8O6PWATge0aGGF2IjsBJMRIqL8Ugp45x2p9+r1yFaRlSulDAoCvLwsHBuRHWEyQkSUX7NmSXOHiwswcuQjN//jDynDwy0cF5GdYTJCRJQfJ09mj7Y6atQjn9O9dUtGiQeA+vUtHBuRnWEyQkRkKqWAN98Ebt8GnnwSmDLlkbv88guQmgqUKgVUr26FGInsCJMRIiJTZGQA3bsDP/8sg4UsXCg9Uh9h+3Ypy5eXuzpElI2/EkREpvjuO2DFCql/+ilQs+Yjdzl/XrqXALlO5Evk1JiMEBGZYsYMKYcNk5l58+CnnwCDQeqZ3UyIKBuTESKivLp0ScYVAYDhw/O829dfSzl5MlCihPnDIrJ3TEaIiPJCKWD2bCkbNgQqVcrTbocOAX/+KfWePS0WHZFdYzJCRJQXS5cC778vdRNmuPvpJynbtctz/kLkdJiMEBE9ilLAtGlS7949TwOcZdqxQ8o2bSwQF5GDYDJCRPQoU6bI/ZZixYDPPnvorLz3OnBAngAGgLZtLRcekb1jMkJE9DAz
ZwLvvSf1SZNk1LI8yuzr2rYtEBxs/tCIHAWTESKi3Oj18ggMANSqBYwYYdLuO3dK+YiR4omcHpMRIqKc3LwJvPACcP06UKZM9oR4Jti4Uco8jItG5NTyduOTiMiZ3LkDNG4MxMZK/5Dly026PQMAc+cC165JvVcvC8RI5EDYMkJEdC+9Xmayi42V5c8/N7n3qVLABx9IvWtXoHhxM8dI5GCYjBAR3SsyEjhzRuoffQT072/yIVatAuLi5K7O0qVmjo/IATEZISLKtGYNMHWq1IcPB0aNytdhdu2SMjCQw78T5QWTESIiAPjtt+yRVZs3lxl58+mHH6QMDzdDXEROgMkIETk3peS+SvfuUu/cGdi61eQnZzIlJgInTkg9NNR8YRI5Mj5NQ0TOrWdP4NtvpV6+PLBgAeDunu/D/fuvlOXKASVLmiE+IifAlhEicl5z5mQnIv36AUePAn5+BTpkdLSU1aoVMDYiJ8JkhIicT0aGjKw6ZIgsv/km8OWXZnkGN3MI+MceK/ChiJwGkxEicj7r1gETJki9Sxdg9myzHfqnn6Rs2dJshyRyeExGiMj5zJwpZe3awHffATqdWQ5761Z2vWNHsxySyCkwGSEi55GeDgwYIDPYubgAy5YBbm5mO3zmUzTe3jKIKxHlDZ+mISLnoJTckomKkuVPPgEaNjTrKY4fl7JqVbM1thA5BSYjROQcli/PTkRmzJARVs0s87FePklDZBrepiEix7d7NzB0qNQjIiySiNy8mT1oa1CQ2Q9P5NDYMkJEjm3vXqBFCyAtDWjSBJg40SKn+e474MoVwMdHnhQmorxjywgROa7UVKBTJ0lEypcHNm4s0OiqD/O//0nZuTPg4WGRUxA5LCYjROS4xo8HEhKkHhUFlC5tkdNcvgzs2SP1AQMscgoih8ZkhIgc040bwMcfS33kSODxxy12qq1bpfT2BoKDLXYaIofFZISIHFPmwGaAxfqJAEBcHPD++1J//XU+0kuUH0xGiMjxnDoFLFki9X79gKJFLXIagwF47TXgyBFpFeEtGqL84dM0RORYbt4E2rQBzpyRJGTMGIud6pdfsm/RbNoE1KljsVMROTS2jBCR48hsqjhzRmbg/fNPoEoVi51u0yYpO3cGmjWz2GmIHF6+kpE5c+agUqVKKFy4MIKCgrBr165ct12wYAFatmyJkiVLomTJkggJCXno9kRE+XL9OvDii8DatYCrK7BiBVCvnkVP+csvUvbta9HTEDk8k5ORlStXIjw8HBMmTMC+ffvQoEEDhIaG4tKlSzluv23bNnTr1g2//vorYmJiEBAQgPbt2+PChQsFDp6ICABw4YLMTLdunSx//bWML2JByclAbKzUmze36KmIHJ5OKaVM2SEoKAhPPPEEZs+eDQAwGAwICAjA0KFDMXr06Efun5GRgZIlS2L27Nno1atXns6p1+vh5eWFpKQkeHp6mhIuETm6xESgXTtg3z5Z/vhjeZTXwpYskadnypcHzp2z+OmI7FJev79N6sCalpaGvXv3IiIiImudi4sLQkJCEBMTk6dj3Lx5E3fu3EGpUqVy3SY1NRWpqalZy3q93pQwiciZdO0qiYinpwxsZoWBPm7dAt57T+pDhlj8dEQOz6TbNFeuXEFGRgZ8fX2N1vv6+iI+Pj5Px3j33Xfh7++PkJCQXLeJjIyEl5dX1isgIMCUMInIWUyaBPz2m9RXrLDaiGM7dgDnzwNubsDgwVY5JZFDs+rTNFOnTsWKFSuwdu1aFC5cONftIiIikJSUlPU6xzZQIrrf0KHZg5mNGQN06GC1Uy9eLGWvXvLQDhEVjEm3aXx8fODq6oqEzLke7kpISICfn99D9/3kk08wdepUbNmyBfXr13/oth4eHvDgTFNElJvYWOBuvzUMHQp8+KHVTn3unDTCAEBYmNVOS+TQTGoZcXd3R2BgIKKjo7PWGQwGREdHI/ghzaMff/wx3n//fURFRaFJkyb5j5aI6Pp1oH17qVeu
bDzsuxVkji1Svjzw9NNWPTWRwzJ5BNbw8HD07t0bTZo0QdOmTTFz5kykpKSgT58+AIBevXqhXLlyiIyMBAB89NFHGD9+PJYtW4ZKlSpl9S0pXrw4irN9k4hMcfw40L07cPYsUKYMEB0NuFh37MaoKCkrV7bqaYkcmsnJSFhYGC5fvozx48cjPj4eDRs2RFRUVFan1rNnz8Llnj8On3/+OdLS0vDSSy8ZHWfChAmYaMHJq4jIwaxZA/TuLcO9Fy8OrFpl9Yzg+nXgxx+lPmOGVU9N5NBMHmdECxxnhMjJXb8OBAQAKSnAU08BixYBlSpZPYy33wamTwcaNAAOHLD66YnsjkXGGSEi0sQ330giUqsWsGWLDPduZX//nd09ZcIEq5+eyKFxojwisl137gAffCBNEgAwcKAmiUhaGtCqlczD16YN8OyzVg+ByKGxZYSIbNeHH8rAZoDMxjtsmCZhDBsGZA4E/eWXVu8zS+Tw+CtFRLYpPh747DOp9+kjk99pkAWcPQt88YXUP/gAqFLF6iEQOTwmI0Rke27ckPsh164BtWsD8+cDOp3Vw1AKeOstKVu0AMaOtXoIRE6ByQgR2Zb0dODll4F//5XlH38ECmlzR3n1auCHH2QOmsmTNQmByCkwGSEi2zJ/fvbIYosWAdWqaRJGRgbwzjtSf+cdaaghIstgMkJEtuP6deD996U+dizw+uuahbJsmcxDAwB9+2oWBpFTYDJCRNozGKQVpGlT6bjq65vdLKGB6GhgyBCpt26tWeMMkdPgo71EpK0bN2TGuT//lGUfH2DzZsDbW5NwEhKAzp2B27dlbJH16zUJg8ipsGWEiLSjlAxklpmITJoE7NsH1KunWUiLFkkiUqSI9J3lfJ5ElseWESLSRkYG0K8fsHy5LM+YAQwfrmlIcXFARITUJ04EvLw0DYfIaTAZISLrO3lSeoVu3y7LkZGaJyIZGUCzZtnLAwZoFwuRs+FtGiKyntRUYMMGoG1bSUQKFwZmzQJGj9Y6MsyYIaOtAjLYa8mS2sZD5EzYMkJElqeU9Af59FPpsAoA5cpJR9XatbWNDfJE8fjxUh89GujZU9t4iJwNW0aIyLKUAkaOlGTkxg3Az0/GWN+3zyYSEUCGNLl1C6heHZgyRetoiJwPW0aIyLImT5YWEQD4/HN5esaGpr3duTN7IryXX9ZkChwip8dkhIgs49w5GTnsxx9lecIE4I03tI3pPhcvyoTABgPQvbvMyktE1sdkhIjM77ffgOefB65elVnmhg0D3ntP66iMKCW3ZW7eBEqXBubMYasIkVZsp62UiBzDhAkydOnVq0DjxsCBA8C0aZrNvJublSslEQE0HfCViMBkhIjMafVq6SMCAF26SAtJnTraxpSDr78GunWTeufOQIMG2sZD5OyYjBBRwaWlAWPGAK++KsvduklfkaJFtY3rPikpMhFw796y3KyZJCZEpC3bajclIvtz8ybw4otAVJQst28PLFigbUw5iImR8UNOnJDlsDDg228BV1dt4yIitowQUUHs3Ak89ZQkIkWLAkuWSL1YMa0jM7J+PfDkk5KIlC8P/PILsGIFExEiW8GWESIynV4vk9ytWSPLXl7Axo3yjW9DlAIWLpQnjAEgOBjYtIkT4BHZGraMEJFpEhNlUI7MRGTAAODvv20uEQGAmTOB/v2B27eBNm2khYSJCJHtYcsIEeXdzp0yTGl8vNzj+PFHoGNHraPK1caNUr76qnRUdXPTNh4iyhlbRojo0a5dk6dlWraURKRGDZtPRO7cAXbtkvrbbzMRIbJlbBkhotzduSMzx338cfYIYcHBMkqYjXVSvd/HH0vXlsceAxo10joaInoYJiNElLPt22VAjjNnZLl4cWD8eGDECJsbTfV+ffsCixZJ/d13+dQMka2z7b8oRGR9SUnA6NHAvHmyXLgw8L//SadVG28NAeTpmcxEZNw44J13tI2HiB6NyQgRievXJemYODF73UsvAbNnA76+
moVlim++kSeOAcmdMkemJyLbxmSEyNldvQpMmiTNCTduyLrHHpOZ5Nq10zY2E6xYAfTqJfUuXWT8NSKyD0xGiJxRRob0CYmJAebPB86elfWVKkli8sIL0kfEDhgMwJtvAl98IcuNG8sQKDberYWI7sFfVyJnYzAAr70mTQmZAgKA998HevSwq29xpaSzamYryHPPyXgi7u6ahkVEJrKfvzpElH9KAUeOyKQsn3wCxMXJ+nr1gKFDJQmxsRl282LatOxEZO5cYNAgbeMhovxhMkLk6K5fB4YPlyaDe731FjBjBuBin2MfRkbKOGwAMHgwExEie8ZkhMhRpaYCI0cCS5dKQgIAISHAM89In5DKlbWNrwD++is7EenYEZg1S9t4iKhgmIwQOZqkJOmcOm4ccPCgrKtRQ8YO6dNH29gKSCm5yzR2rCxXrAj89JPdNu4Q0V1MRogcxc2bMtzoF1/IMO6ADD06YoQM6W7nk7PcugV06gT8+qssN2ki0+MwESGyf0xGiBzBn39Kp4kDB2S5TBkgLAwID5fHde3ckSNA+/bAxYuyPGsWMGQIExEiR8FkhMieJSRI59R7H9Ndvhx49VXNQrKE116TRKRIERmLrUsXrSMiInNiMkJkb+LigHXrpGPq7t3Zt2Sefx6YMAFo0EDT8Mztm2+yG3wOHJDuL0TkWJiMENmL/fvlEZKoKOP1jRpJr862bbWJy0IOH5aP+9NPshwRwUSEyFExGSGydZs3A+PHS7+QTPXqySR2zz4L1K8P6HTaxWcB06YBo0ZlL7/5pvH8fUTkWJiMENkipYC//5YZc7/6Sta5uABdu8oju40baxufhVy9Ko/tZs4zExgonVWbN9c2LiKyLCYjRLbCYJDWjzVrgPXrgf/+y36vXz9JQipW1C4+C9u6VcZkU0qWW7eWRiE7fyKZiPKAyQiRlpSSmXNXrpQE5OTJ7PcKFZJv5Ndfl8dJHFRKioxKP358diLy2WcyxLuD3X0iolwwGSHSQkKCzOw2fz4QH5+93tNTBtR46SWgVSugbFntYrSgO3eAY8dkkrt584DkZFnfogWwcSNQooS28RGRdTEZIbKGlBT5ll2/Hjh0SJ6MyVSkiHREfeklIDQUKF5cuzgtbMcOYOFCYNUqGVE1U+HCMnjsqFF2OXkwERUQkxEiS1BK5oXZuhXYskVeaWnG2zzxBPDOO5KIeHhoE6cVnDkDLFoE/PBD9nghgORc9esDr7wCDBwoORkROScmI0TmoBRw7hywdy+wYYO87r39AgBVqgAvvww0bSqJSECANrFawYULwMyZMibb9u3Z611cZHDYIUOAZs3YJ4SIBJMRorxSCrh0CTh+HIiNlU4Px47JUy+nTsmtmHsVKSIdUFu0kNFRa9d26G/fCxdkJPpdu4BNm4AbN7Lfa9QIGDpUJrorU0a7GInINjEZIbrX7dvyRMv+/TIE6MWL8rpwATh79sGE416FCskQoSEhQOfO0gHVgW+/3LwJ/PGHTIuzfbvkaPeqW1cmDG7UCGjY0KHzMCIqICYj5DyUAhITJak4fVoSjIQE6dRw8qS8Ll7Mfr40JzodUKGCJB2Zr+rVgapVgcqVHXZQDKXkUm3fLtPiHDggjUIGg/F2zZvLuGzBwVLnrLpElBdMRsi+ZWQA167JN+WlSw++7l1//vyDnUhzktmzskEDGWSsXDnA3x8oXx6oVAlwd7f4x9KSUtLdZcsW6QJz+LD0xb18+cFty5aVWy8vvijdYB57zPrxEpH9y1cyMmfOHEybNg3x8fFo0KABPvvsMzRt2jTX7VevXo1x48bh9OnTqF69Oj766CN07Ngx30GTg8vIkA4HyclAUpIkFPHxxq+4OODIEXnv/v+eP0rp0pJUlC8vHRgCAqRzaebLx8eh7ykoBVy5IvmZXi/lyZPAP//I6+hR4Pr1B/dzcZFGoBdekDtQDRsCfn5WD5+IHJDJycjKlSsRHh6OefPmISgoCDNnzkRoaChiY2NR
JoeeaX/88Qe6deuGyMhIdO7cGcuWLcNzzz2Hffv2oW7dumb5EGQDMjLkm02vlyTi/jKv9eTkh/fLyIlOJ/8lL1PG+OXra7zs7y//lXfAfhxpadJAdPWqvC5fNm4cSkiQfrYJCXKn6lENRDodUKuWDHtSt640EtWpwzFAiMgydEo97Ab5g4KCgvDEE09g9uzZAACDwYCAgAAMHToUo0ePfmD7sLAwpKSkYP369VnrmjVrhoYNG2LevHl5Oqder4eXlxeSkpLg6elpSriOSylJAO7cAVJTH3zdvp3zekttY2oCkReFCsmIpL6+8l/w+1/Vq0s/DR8f2daOKCWDft2bg+WUl+XlpdcbDyCWV489Bnh5SVmxojzsU6eOlDVqcNwPIiq4vH5/m/QXPC0tDXv37kVERETWOhcXF4SEhCAmJibHfWJiYhAeHm60LjQ0FOvWrcv1PKmpqUhNTc1a1uv1poSZZzOf3YpTp3R3OywqKAMAKFlWAJS6+5bK6tRovKzubn53+7v7Zm+Tl3V3lw1KjmPIPL/h7jrIbQiDQfYx3H3vnhxSwfiWwv3LD65zhUIxAMVM3C8P27i4Qrm5SUfOQndLNzeoe+pwKwRVyB1wK3TPNoWgXKXM3Fe5uAD3HtsAqAsALtx3fvXw5bxuYzBIfpdZ5vYqyPuZ76Wnm3536VF0OqBUKUkufHyyG4Yyy8w7U97e0kDk4F1fiMiOmJSMXLlyBRkZGfD19TVa7+vri3///TfHfeLj43PcPv7+AaHuERkZiUmTJpkSWr6s2loaMTfqWfw8TsUAIPXui/KkeHGZi6VECWkIyqyb8vL2lhefXiEie2STbdsRERFGrSl6vR4BFhitsvdzSWh7Zocs6HSATgediy6782LWurt16O7bDjmsy943a1l3z7F0965zubsOgItOvkl0LrKfi4vRS+fqArhmLrtmr3NxAQq5yrrMY90jp36YeVmX3/3MeSxrx6DTAa6u2S8XF+Pl3F752a5QIUlCihVjAkFEZFIy4uPjA1dXVyQkJBitT0hIgF8u3er9/PxM2h4APDw84GGFTob/900Li5+DiIiIHs6k/5O5u7sjMDAQ0dHRWesMBgOio6MRHByc4z7BwcFG2wPA5s2bc92eiIiInIvJt2nCw8PRu3dvNGnSBE2bNsXMmTORkpKCPn36AAB69eqFcuXKITIyEgAwbNgwtG7dGp9++ik6deqEFStWYM+ePZg/f755PwkRERHZJZOTkbCwMFy+fBnjx49HfHw8GjZsiKioqKxOqmfPnoXLPTfBn3zySSxbtgzvvfcexowZg+rVq2PdunUcY4SIiIgA5GOcES1wnBEiIiL7k9fvb/bjJyIiIk0xGSEiIiJNMRkhIiIiTTEZISIiIk0xGSEiIiJNMRkhIiIiTTEZISIiIk0xGSEiIiJNMRkhIiIiTZk8HLwWMgeJ1ev1GkdCREREeZX5vf2owd7tIhlJTk4GAAQEBGgcCREREZkqOTkZXl5eub5vF3PTGAwGXLx4ESVKlIBOpzPbcfV6PQICAnDu3DnOefMIvFam4fXKO16rvOO1yjteq7yz5LVSSiE5ORn+/v5Gk+jezy5aRlxcXFC+fHmLHd/T05M/rHnEa2UaXq+847XKO16rvOO1yjtLXauHtYhkYgdWIiIi0hSTESIiItKUUycjHh4emDBhAjw8PLQOxebxWpmG1yvveK3yjtcq73it8s4WrpVddGAlIiIix+XULSNERESkPSYjREREpCkmI0RERKQpJiNERESkKadORubMmYNKlSqhcOHCCAoKwq5du7QOyaomTpwInU5n9KpVq1bW+7dv38bgwYPx2GOPoXjx4njxxReRkJBgdIyzZ8+iU6dOKFq0KMqUKYORI0ciPT3d2h/FInbs2IEuXbrA398fOp0O69atM3pfKYXx48ejbNmyKFKkCEJCQvDff/8ZbXPt2jX06NEDnp6e8Pb2Rr9+/XDjxg2jbQ4ePIiWLVuicOHC
CAgIwMcff2zpj2Z2j7pWr7/++gM/a88884zRNs5wrSIjI/HEE0+gRIkSKFOmDJ577jnExsYabWOu37tt27ahcePG8PDwQLVq1bB48WJLfzyzy8v1euqppx742XrjjTeMtnGG6/X555+jfv36WQOXBQcHY9OmTVnv2/zPlXJSK1asUO7u7mrhwoXqyJEjasCAAcrb21slJCRoHZrVTJgwQT3++OMqLi4u63X58uWs99944w0VEBCgoqOj1Z49e1SzZs3Uk08+mfV+enq6qlu3rgoJCVH79+9XGzduVD4+PioiIkKLj2N2GzduVGPHjlXff/+9AqDWrl1r9P7UqVOVl5eXWrdunfr7779V165dVeXKldWtW7eytnnmmWdUgwYN1J9//ql+++03Va1aNdWtW7es95OSkpSvr6/q0aOHOnz4sFq+fLkqUqSI+uKLL6z1Mc3iUdeqd+/e6plnnjH6Wbt27ZrRNs5wrUJDQ9WiRYvU4cOH1YEDB1THjh1VhQoV1I0bN7K2Mcfv3cmTJ1XRokVVeHi4+ueff9Rnn32mXF1dVVRUlFU/b0Hl5Xq1bt1aDRgwwOhnKykpKet9Z7leP/74o9qwYYM6duyYio2NVWPGjFFubm7q8OHDSinb/7ly2mSkadOmavDgwVnLGRkZyt/fX0VGRmoYlXVNmDBBNWjQIMf3EhMTlZubm1q9enXWuqNHjyoAKiYmRiklX0AuLi4qPj4+a5vPP/9ceXp6qtTUVIvGbm33f8EaDAbl5+enpk2blrUuMTFReXh4qOXLlyullPrnn38UALV79+6sbTZt2qR0Op26cOGCUkqpuXPnqpIlSxpdr3fffVfVrFnTwp/IcnJLRp599tlc93HWa3Xp0iUFQG3fvl0pZb7fu1GjRqnHH3/c6FxhYWEqNDTU0h/Jou6/XkpJMjJs2LBc93Hm61WyZEn15Zdf2sXPlVPepklLS8PevXsREhKStc7FxQUhISGIiYnRMDLr+++//+Dv748qVaqgR48eOHv2LABg7969uHPnjtE1qlWrFipUqJB1jWJiYlCvXj34+vpmbRMaGgq9Xo8jR45Y94NY2alTpxAfH290fby8vBAUFGR0fby9vdGkSZOsbUJCQuDi4oK//vora5tWrVrB3d09a5vQ0FDExsbi+vXrVvo01rFt2zaUKVMGNWvWxKBBg3D16tWs95z1WiUlJQEASpUqBcB8v3cxMTFGx8jcxt7/vt1/vTJ9++238PHxQd26dREREYGbN29mveeM1ysjIwMrVqxASkoKgoOD7eLnyi4myjO3K1euICMjw+iiA4Cvry/+/fdfjaKyvqCgICxevBg1a9ZEXFwcJk2ahJYtW+Lw4cOIj4+Hu7s7vL29jfbx9fVFfHw8ACA+Pj7Ha5j5niPL/Hw5ff57r0+ZMmWM3i9UqBBKlSpltE3lypUfOEbmeyVLlrRI/Nb2zDPP4IUXXkDlypVx4sQJjBkzBh06dEBMTAxcXV2d8loZDAYMHz4czZs3R926dQHAbL93uW2j1+tx69YtFClSxBIfyaJyul4A0L17d1SsWBH+/v44ePAg3n33XcTGxuL7778H4FzX69ChQwgODsbt27dRvHhxrF27FnXq1MGBAwds/ufKKZMREh06dMiq169fH0FBQahYsSJWrVplN798ZB9effXVrHq9evVQv359VK1aFdu2bUO7du00jEw7gwcPxuHDh7Fz506tQ7ELuV2vgQMHZtXr1auHsmXLol27djhx4gSqVq1q7TA1VbNmTRw4cABJSUlYs2YNevfuje3bt2sdVp445W0aHx8fuLq6PtCTOCEhAX5+fhpFpT1vb2/UqFEDx48fh5+fH9LS0pCYmGi0zb3XyM/PL8drmPmeI8v8fA/7GfLz88OlS5eM3k9PT8e1a9ec/hpWqVIFPj4+OH78OADnu1ZDhgzB+vXr8euvv6J8+fJZ6831e5fbNp6ennb5H43crldOgoKCAMDoZ8tZrpe7uzuq
VauGwMBAREZGokGDBpg1a5Zd/Fw5ZTLi7u6OwMBAREdHZ60zGAyIjo5GcHCwhpFp68aNGzhx4gTKli2LwMBAuLm5GV2j2NhYnD17NusaBQcH49ChQ0ZfIps3b4anpyfq1Klj9fitqXLlyvDz8zO6Pnq9Hn/99ZfR9UlMTMTevXuzttm6dSsMBkPWH8zg4GDs2LEDd+7cydpm8+bNqFmzpt3ddjDF+fPncfXqVZQtWxaA81wrpRSGDBmCtWvXYuvWrQ/cdjLX711wcLDRMTK3sbe/b4+6Xjk5cOAAABj9bDnL9bqfwWBAamqqffxcFbgLrJ1asWKF8vDwUIsXL1b//POPGjhwoPL29jbqSezo3n77bbVt2zZ16tQp9fvvv6uQkBDl4+OjLl26pJSSR8EqVKigtm7dqvbs2aOCg4NVcHBw1v6Zj4K1b99eHThwQEVFRanSpUs7zKO9ycnJav/+/Wr//v0KgJo+fbrav3+/OnPmjFJKHu319vZWP/zwgzp48KB69tlnc3y0t1GjRuqvv/5SO3fuVNWrVzd6XDUxMVH5+vqqnj17qsOHD6sVK1aookWL2tXjqko9/FolJyerd955R8XExKhTp06pLVu2qMaNG6vq1aur27dvZx3DGa7VoEGDlJeXl9q2bZvRo6g3b97M2sYcv3eZj2COHDlSHT16VM2ZM8fuHlVV6tHX6/jx42ry5Mlqz5496tSpU+qHH35QVapUUa1atco6hrNcr9GjR6vt27erU6dOqYMHD6rRo0crnU6nfvnlF6WU7f9cOW0yopRSn332mapQoYJyd3dXTZs2VX/++afWIVlVWFiYKlu2rHJ3d1flypVTYWFh6vjx41nv37p1S7355puqZMmSqmjRour5559XcXFxRsc4ffq06tChgypSpIjy8fFRb7/9trpz5461P4pF/PrrrwrAA6/evXsrpeTx3nHjxilfX1/l4eGh2rVrp2JjY42OcfXqVdWtWzdVvHhx5enpqfr06aOSk5ONtvn7779VixYtlIeHhypXrpyaOnWqtT6i2TzsWt28eVO1b99elS5dWrm5uamKFSuqAQMGPJD4O8O1yukaAVCLFi3K2sZcv3e//vqratiwoXJ3d1dVqlQxOoe9eNT1Onv2rGrVqpUqVaqU8vDwUNWqVVMjR440GmdEKee4Xn379lUVK1ZU7u7uqnTp0qpdu3ZZiYhStv9zpVNKqYK3rxARERHlj1P2GSEiIiLbwWSEiIiINMVkhIiIiDTFZISIiIg0xWSEiIiINMVkhIiIiDTFZISIiIg0xWSEiIiINMVkhIiIiDTFZISIiIg0xWSEiIiINMVkhIiIiDT1/+A6fWa5zfvYAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "n_samples = 1000\n",
    "n_classes = 3\n",
    "\n",
    "# Random logits drawn uniformly from [-5, 5), turned into probabilities with softmax\n",
    "t = (torch.rand(n_samples, n_classes) - .5) * 10\n",
    "probas = F.softmax(t, -1)\n",
    "sharpened_probas = Sharpen()(probas)\n",
    "# Sorted probabilities before (red) and after (blue) sharpening\n",
    "plt.plot(probas.flatten().sort().values, color='r')\n",
    "plt.plot(sharpened_probas.flatten().sort().values, color='b')\n",
    "plt.show()\n",
    "# Over the second half of the samples, the summed top-class probability must be larger after sharpening\n",
    "test_gt(sharpened_probas[n_samples//2:].max(-1).values.sum().item(), probas[n_samples//2:].max(-1).values.sum().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class Sequential(nn.Sequential):\n",
    "    \"\"\"Sequential container that allows you to pass one or multiple inputs\n",
    "\n",
    "    The positional inputs are unpacked into the first module. After that, a list, tuple or `L`\n",
    "    returned by a module is unpacked into the next module, while any other output is passed as is.\n",
    "    \"\"\"\n",
    "    def forward(self, *x):\n",
    "        # `x` starts out as the tuple of positional inputs, so the first module always receives them unpacked\n",
    "        for module in self._modules.values():\n",
    "            x = module(*x) if isinstance(x, (list, tuple, L)) else module(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TimeDistributed(nn.Module):\n",
    "    \"\"\"Wrapper that applies `module` to an input whose leading axes have been merged\n",
    "\n",
    "    Inputs of rank <= 2 go straight through `module`. Otherwise the input is viewed as\n",
    "    (-1, input_size), passed to `module`, and the result is viewed back as a rank 3 tensor:\n",
    "    (samples, timesteps, output_size) if `batch_first` else (timesteps, samples, output_size).\n",
    "    \"\"\"\n",
    "    def __init__(self, module, batch_first=False):\n",
    "        super().__init__()\n",
    "        self.batch_first = batch_first\n",
    "        self.module = module\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Nothing to fold when there is no separate time axis\n",
    "        if x.dim() <= 2:\n",
    "            return self.module(x)\n",
    "\n",
    "        # Merge samples and timesteps: (samples * timesteps, input_size)\n",
    "        flat = x.contiguous().view(-1, x.size(-1))\n",
    "        out = self.module(flat)\n",
    "\n",
    "        # Split the merged axis again\n",
    "        if not self.batch_first:\n",
    "            return out.view(-1, x.size(1), out.size(-1))  # (timesteps, samples, output_size)\n",
    "        return out.contiguous().view(x.size(0), -1, out.size(-1))  # (samples, timesteps, output_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class Temp_Scale(Module):\n",
    "    \"Temperature Scaling (dirichlet=False) or Single-parameter Dirichlet calibration (dirichlet=True): divides the input by one learnable scalar\"\n",
    "    def __init__(self, temp=1., dirichlet=False):\n",
    "        # `weight` holds the learnable temperature; `bias` is kept as an attribute and is always None here\n",
    "        self.weight = nn.Parameter(tensor(temp))\n",
    "        self.bias = None\n",
    "        self.log_softmax = dirichlet\n",
    "\n",
    "    def forward(self, x):\n",
    "        # The Dirichlet variant works on log-probabilities instead of on the raw input\n",
    "        scores = F.log_softmax(x, dim=-1) if self.log_softmax else x\n",
    "        return scores.div(self.weight)\n",
    "\n",
    "\n",
    "class Vector_Scale(Module):\n",
    "    \"Vector Scaling (dirichlet=False) or Diagonal Dirichlet calibration (dirichlet=True): one learnable weight and bias per class\"\n",
    "    def __init__(self, n_classes=1, dirichlet=False):\n",
    "        # Identity initialization: weights of one, biases of zero\n",
    "        self.weight = nn.Parameter(torch.ones(n_classes))\n",
    "        self.bias = nn.Parameter(torch.zeros(n_classes))\n",
    "        self.log_softmax = dirichlet\n",
    "\n",
    "    def forward(self, x):\n",
    "        scores = F.log_softmax(x, dim=-1) if self.log_softmax else x\n",
    "        scaled = scores.mul(self.weight)\n",
    "        return scaled.add(self.bias)\n",
    "\n",
    "\n",
    "class Matrix_Scale(Module):\n",
    "    \"Matrix Scaling (dirichlet=False) or Dirichlet calibration (dirichlet=True): a learnable n_classes x n_classes linear map\"\n",
    "    def __init__(self, n_classes=1, dirichlet=False):\n",
    "        self.ms = nn.Linear(n_classes, n_classes)\n",
    "        # Start from the identity map: identity weight matrix and zero bias\n",
    "        nn.init.eye_(self.ms.weight)\n",
    "        nn.init.zeros_(self.ms.bias)\n",
    "        # Expose the linear layer's parameters under the same names the other calibrators use\n",
    "        self.weight = self.ms.weight\n",
    "        self.bias = self.ms.bias\n",
    "        self.log_softmax = dirichlet\n",
    "\n",
    "    def forward(self, x):\n",
    "        scores = F.log_softmax(x, dim=-1) if self.log_softmax else x\n",
    "        return self.ms(scores)\n",
    "    \n",
    "    \n",
    "def get_calibrator(calibrator=None, n_classes=1, **kwargs):\n",
    "    \"\"\"Returns the calibration layer selected by `calibrator` (case-insensitive)\n",
    "\n",
    "    Accepted names are 'temp', 'vector' and 'matrix', plus their Dirichlet variants 'dtemp', 'dvector'\n",
    "    and 'dmatrix'. A None or empty `calibrator` returns `noop`. `kwargs` are passed to the selected class.\n",
    "    \"\"\"\n",
    "    if not calibrator: return noop\n",
    "    name = calibrator.lower()\n",
    "    # Each family has a plain and a 'd'-prefixed (dirichlet=True) variant\n",
    "    if name in ('temp', 'dtemp'):\n",
    "        return Temp_Scale(dirichlet=name == 'dtemp', **kwargs)\n",
    "    if name in ('vector', 'dvector'):\n",
    "        return Vector_Scale(n_classes=n_classes, dirichlet=name == 'dvector', **kwargs)\n",
    "    if name in ('matrix', 'dmatrix'):\n",
    "        return Matrix_Scale(n_classes=n_classes, dirichlet=name == 'dmatrix', **kwargs)\n",
    "    assert False, f'please, select a correct calibrator instead of {calibrator}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "c_out = 3\n",
    "\n",
    "t = torch.rand(bs, c_out)\n",
    "# Newly created non-Dirichlet calibrators are expected to return their input unchanged\n",
    "for calibrator, cal_name in zip(['temp', 'vector', 'matrix'], ['Temp_Scale', 'Vector_Scale', 'Matrix_Scale']): \n",
    "    cal = get_calibrator(calibrator, n_classes=c_out)\n",
    "    test_eq(cal(t), t)\n",
    "    test_eq(cal.__class__.__name__, cal_name)\n",
    "# Newly created Dirichlet variants are expected to return the log-softmax of their input\n",
    "for calibrator, cal_name in zip(['dtemp', 'dvector', 'dmatrix'], ['Temp_Scale', 'Vector_Scale', 'Matrix_Scale']):\n",
    "    cal = get_calibrator(calibrator, n_classes=c_out)\n",
    "    test_eq(cal(t), F.log_softmax(t, dim=1))\n",
    "    test_eq(cal.__class__.__name__, cal_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "c_out = 3\n",
    "\n",
    "t = torch.rand(bs, c_out)\n",
    "\n",
    "# Every calibrator, with and without dirichlet=True, must preserve the input shape\n",
    "test_eq(Temp_Scale()(t).shape, t.shape)\n",
    "test_eq(Vector_Scale(c_out)(t).shape, t.shape)\n",
    "test_eq(Matrix_Scale(c_out)(t).shape, t.shape)\n",
    "test_eq(Temp_Scale(dirichlet=True)(t).shape, t.shape)\n",
    "test_eq(Vector_Scale(c_out, dirichlet=True)(t).shape, t.shape)\n",
    "test_eq(Matrix_Scale(c_out, dirichlet=True)(t).shape, t.shape)\n",
    "\n",
    "# With default arguments the non-Dirichlet calibrators must act as the identity\n",
    "test_eq(Temp_Scale()(t), t)\n",
    "test_eq(Vector_Scale(c_out)(t), t)\n",
    "test_eq(Matrix_Scale(c_out)(t), t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "c_out = 5\n",
    "\n",
    "t = torch.rand(bs, c_out)\n",
    "# A new Vector_Scale is the identity, and its weight is a trainable Parameter filled with ones\n",
    "test_eq(Vector_Scale(c_out)(t), t)\n",
    "test_eq(Vector_Scale(c_out).weight.data, torch.ones(c_out))\n",
    "test_eq(Vector_Scale(c_out).weight.requires_grad, True)\n",
    "test_eq(type(Vector_Scale(c_out).weight), torch.nn.parameter.Parameter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "c_out = 3\n",
    "# NOTE(review): `weight` and `bias` are not used in this cell - confirm no later cell relies on them before removing\n",
    "weight = 2\n",
    "bias = 1\n",
    "\n",
    "t = torch.rand(bs, c_out)\n",
    "# Matrix_Scale must preserve the input shape and expose its weight as a trainable Parameter\n",
    "test_eq(Matrix_Scale(c_out)(t).shape, t.shape)\n",
    "test_eq(Matrix_Scale(c_out).weight.requires_grad, True)\n",
    "test_eq(type(Matrix_Scale(c_out).weight), torch.nn.parameter.Parameter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class LogitAdjustmentLayer(Module):\n",
    "    \"\"\"Logit Adjustment for imbalanced datasets\n",
    "\n",
    "    Adds `class_priors` to the logits it receives.\n",
    "    \"\"\"\n",
    "    def __init__(self, class_priors):\n",
    "        if isinstance(class_priors, nn.Parameter):\n",
    "            # A Parameter keeps being registered as a (trainable) parameter of the layer\n",
    "            self.class_priors = class_priors\n",
    "        else:\n",
    "            # Registered as a non-persistent buffer so it follows the layer across `.to()`/`.cuda()` calls\n",
    "            # (a plain attribute stays on its original device) without adding a key to the `state_dict`\n",
    "            self.register_buffer('class_priors', torch.as_tensor(class_priors), persistent=False)\n",
    "    def forward(self, x):\n",
    "        return x.add(self.class_priors)\n",
    "    \n",
    "LogitAdjLayer = LogitAdjustmentLayer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs, n_classes = 16, 3\n",
    "class_priors = torch.rand(n_classes)\n",
    "logits = torch.randn(bs, n_classes) * 2\n",
    "# The layer output must equal the logits plus the class priors (broadcast over the batch)\n",
    "test_eq(LogitAdjLayer(class_priors)(logits), logits + class_priors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class PPV(Module):\n",
    "    \"Proportion of values greater than 0 along `dim`\"\n",
    "    def __init__(self, dim=-1): \n",
    "        self.dim = dim\n",
    "    def forward(self, x): \n",
    "        n_positive = (x > 0).sum(dim=self.dim).float()\n",
    "        return n_positive / x.shape[self.dim]\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim})'\n",
    "    \n",
    "\n",
    "class PPAuc(Module):\n",
    "    \"Sum of the positive values divided by the sum of the absolute values along `dim`\"\n",
    "    def __init__(self, dim=-1): \n",
    "        self.dim = dim\n",
    "    def forward(self, x): \n",
    "        positive_sum = F.relu(x).sum(self.dim)\n",
    "        total_sum = abs(x).sum(self.dim) + 1e-8  # the small constant avoids a division by zero\n",
    "        return positive_sum / total_sum\n",
    "    def __repr__(self): return f'{self.__class__.__name__}(dim={self.dim})'\n",
    "    \n",
    "    \n",
    "class MaxPPVPool1d(Module):\n",
    "    \"Drop-in replacement for AdaptiveConcatPool1d - multiplies nf by 2 (max and proportion of positive values over the last axis)\"\n",
    "    def forward(self, x):\n",
    "        max_values = x.max(dim=-1).values\n",
    "        ppv = (x > 0).sum(dim=-1).float() / x.shape[-1]\n",
    "        # Concatenate both statistics and add back a trailing axis of size 1\n",
    "        pooled = torch.cat((max_values, ppv), dim=-1)\n",
    "        return pooled.unsqueeze(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 2\n",
    "nf = 5\n",
    "sl = 4\n",
    "\n",
    "t = torch.rand(bs, nf, sl)\n",
    "# Output must be (bs, nf*2, 1), i.e. the same shape AdaptiveConcatPool1d(1) produces\n",
    "test_eq(MaxPPVPool1d()(t).shape, (bs, nf*2, 1))\n",
    "test_eq(MaxPPVPool1d()(t).shape, AdaptiveConcatPool1d(1)(t).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class AdaptiveWeightedAvgPool1d(Module):\n",
    "    '''Global Pooling layer that performs a weighted average along the temporal axis\n",
    "    \n",
    "    It can be considered as a channel-wise form of local temporal attention. Inspired by the paper: \n",
    "    Hyun, J., Seong, H., & Kim, E. (2019). Universal Pooling--A New Pooling Method for Convolutional Neural Networks. arXiv preprint arXiv:1907.11440.\n",
    "\n",
    "    The weights come from a stack of `n_layers` `LinLnDrop` layers applied to the last axis (hidden layers have\n",
    "    `seq_len * mult` units) followed by a softmax. `dropout` may be a single value or one value per layer.\n",
    "    `n_in` is accepted for API compatibility but is not used to build the layers.'''\n",
    "\n",
    "    def __init__(self, n_in, seq_len, mult=2, n_layers=2, ln=False, dropout=0.5, act=nn.ReLU(), zero_init=True):\n",
    "        layers = nn.ModuleList()\n",
    "        for i in range(n_layers):\n",
    "            is_last = i == n_layers - 1\n",
    "            inp_mult = mult if i > 0 else 1\n",
    "            out_mult = 1 if is_last else mult\n",
    "            p = dropout[i] if is_listy(dropout) else dropout\n",
    "            # `ln` is forwarded to every layer (it used to be hard-coded to False, so the argument was ignored).\n",
    "            # Only hidden layers get an activation; the last layer feeds the softmax directly.\n",
    "            layers.append(LinLnDrop(seq_len * inp_mult, seq_len * out_mult, ln=ln, p=p, \n",
    "                                    act=None if is_last else act))\n",
    "        self.layers = layers\n",
    "        self.softmax = SoftMax(-1)\n",
    "        if zero_init: init_lin_zero(self)\n",
    "\n",
    "    def forward(self, x):\n",
    "        wap = x\n",
    "        for l in self.layers: wap = l(wap)\n",
    "        wap = self.softmax(wap)\n",
    "        # Weighted sum over the last axis\n",
    "        return torch.mul(x, wap).sum(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class GAP1d(Module):\n",
    "    \"Global Adaptive Pooling + Flatten (`nn.AdaptiveAvgPool1d` followed by `Reshape`)\"\n",
    "    def __init__(self, output_size=1):\n",
    "        self.gap, self.flatten = nn.AdaptiveAvgPool1d(output_size), Reshape()\n",
    "    def forward(self, x):\n",
    "        pooled = self.gap(x)\n",
    "        return self.flatten(pooled)\n",
    "    \n",
    "    \n",
    "class GACP1d(Module):\n",
    "    \"Global AdaptiveConcatPool + Flatten (`AdaptiveConcatPool1d` followed by `Reshape`)\"\n",
    "    def __init__(self, output_size=1):\n",
    "        self.gacp, self.flatten = AdaptiveConcatPool1d(output_size), Reshape()\n",
    "    def forward(self, x):\n",
    "        pooled = self.gacp(x)\n",
    "        return self.flatten(pooled)\n",
    "    \n",
    "\n",
    "class GAWP1d(Module):\n",
    "    \"Global AdaptiveWeightedAvgPool1d + Flatten (`AdaptiveWeightedAvgPool1d` followed by `Reshape`)\"\n",
    "    def __init__(self, n_in, seq_len, n_layers=2, ln=False, dropout=0.5, act=nn.ReLU(), zero_init=False):\n",
    "        # The pooling layer is stored as `gacp`, the same attribute name GACP1d uses\n",
    "        pool_kwargs = dict(n_layers=n_layers, ln=ln, dropout=dropout, act=act, zero_init=zero_init)\n",
    "        self.gacp = AdaptiveWeightedAvgPool1d(n_in, seq_len, **pool_kwargs)\n",
    "        self.flatten = Reshape()\n",
    "    def forward(self, x):\n",
    "        pooled = self.gacp(x)\n",
    "        return self.flatten(pooled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class GlobalWeightedAveragePool1d(Module):\n",
    "    \"\"\" Global Weighted Average Pooling layer \n",
    "    \n",
    "    Inspired by Building Efficient CNN Architecture for Offline Handwritten Chinese Character Recognition\n",
    "    https://arxiv.org/pdf/1804.01259.pdf\n",
    "\n",
    "    Every (channel, step) position has its own learnable weight and bias; the pooling weights are the\n",
    "    softmax over the last axis of sigmoid(x * weight + bias).\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, n_in, seq_len): \n",
    "        param_shape = (1, n_in, seq_len)\n",
    "        self.weight = nn.Parameter(torch.ones(*param_shape))\n",
    "        self.bias = nn.Parameter(torch.zeros(*param_shape))\n",
    "\n",
    "    def forward(self, x):\n",
    "        gate = torch.sigmoid(x * self.weight + self.bias)\n",
    "        attn = F.softmax(gate, dim=-1)\n",
    "        # Weighted sum over the last axis\n",
    "        return (x * attn).sum(-1)\n",
    "\n",
    "GWAP1d = GlobalWeightedAveragePool1d\n",
    "\n",
    "def gwa_pool_head(n_in, c_out, seq_len, bn=True, fc_dropout=0.):\n",
    "    \"Head made of GlobalWeightedAveragePool1d, Reshape and a LinBnDrop layer mapping n_in to c_out\"\n",
    "    pool = GlobalWeightedAveragePool1d(n_in, seq_len)\n",
    "    flatten = Reshape()\n",
    "    fc = LinBnDrop(n_in, c_out, p=fc_dropout, bn=bn)\n",
    "    return nn.Sequential(pool, flatten, fc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input of shape (16, 64, 50) pooled by a head built with n_in=64, c_out=5, seq_len=50\n",
    "t = torch.randn(16, 64, 50)\n",
    "head = gwa_pool_head(64, 5, 50)\n",
    "test_eq(head(t).shape, (16, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class AttentionalPool1d(Module):\n",
    "    \"\"\"Global Adaptive Pooling layer inspired by Attentional Pooling for Action Recognition https://arxiv.org/abs/1711.01467\"\"\"\n",
    "    def __init__(self, n_in, c_out, bn=False): \n",
    "        store_attr()\n",
    "        # `self.bn` (stored above as a flag) is replaced by the actual layer, or by None when disabled\n",
    "        self.bn = nn.BatchNorm1d(n_in) if bn else None\n",
    "        self.conv1 = Conv1d(n_in, 1, 1)\n",
    "        self.conv2 = Conv1d(n_in, c_out, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.bn is not None: x = self.bn(x) \n",
    "        attn = self.conv1(x)\n",
    "        feats = self.conv2(x).transpose(1,2)\n",
    "        # The matrix product contracts the last axis of `attn` with the matching axis of `feats`\n",
    "        return (attn @ feats).transpose(1,2)\n",
    "    \n",
    "class GAttP1d(nn.Sequential):\n",
    "    \"AttentionalPool1d followed by a Reshape layer\"\n",
    "    def __init__(self, n_in, c_out, bn=False):\n",
    "        pool = AttentionalPool1d(n_in, c_out, bn=bn)\n",
    "        super().__init__(pool, Reshape())\n",
    "        \n",
    "def attentional_pool_head(n_in, c_out, seq_len=None, bn=True, **kwargs):\n",
    "    \"Head made of AttentionalPool1d and Reshape. `seq_len` is accepted but not used; `kwargs` go to AttentionalPool1d\"\n",
    "    pool = AttentionalPool1d(n_in, c_out, bn=bn, **kwargs)\n",
    "    return nn.Sequential(pool, Reshape())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-channel input\n",
    "bs, c_in, seq_len = 16, 1, 50\n",
    "c_out = 3\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "test_eq(GAP1d()(t).shape, (bs, c_in))\n",
    "test_eq(GACP1d()(t).shape, (bs, c_in*2))\n",
    "# Multi-channel input\n",
    "bs, c_in, seq_len = 16, 4, 50\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "test_eq(GAP1d()(t).shape, (bs, c_in))\n",
    "test_eq(GACP1d()(t).shape, (bs, c_in*2))\n",
    "# GAWP1d with 2 layers, and with a single layer with and without zero initialization\n",
    "test_eq(GAWP1d(c_in, seq_len, n_layers=2, ln=False, dropout=0.5, act=nn.ReLU(), zero_init=False)(t).shape, (bs, c_in))\n",
    "test_eq(GAWP1d(c_in, seq_len, n_layers=1, ln=False, dropout=0.5, zero_init=False)(t).shape, (bs, c_in))\n",
    "test_eq(GAWP1d(c_in, seq_len, n_layers=1, ln=False, dropout=0.5, zero_init=True)(t).shape, (bs, c_in))\n",
    "test_eq(AttentionalPool1d(c_in, c_out)(t).shape, (bs, c_out, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs, c_in, seq_len = 16, 128, 50\n",
    "c_out = 14\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "# The head must reduce (bs, c_in, seq_len) to (bs, c_out)\n",
    "attp = attentional_pool_head(c_in, c_out)\n",
    "test_eq(attp(t).shape, (bs, c_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class PoolingLayer(Module):\n",
    "    \"\"\"Pools a sequence with one of: 'cls', 'max', 'mean', 'max-mean', 'linear', 'conv1d' or 'flatten'\n",
    "\n",
    "    `seq_last` tells whether the sequence axis is the last one (True) or axis 1 (False). When `token` is\n",
    "    True, the first position of the sequence is returned by 'cls' and dropped by every other method.\n",
    "    'linear' and 'conv1d' both use a `nn.Linear(seq_len - token, 1)` layer, so they require `seq_len`.\n",
    "    \"\"\"\n",
    "    def __init__(self, method='cls', seq_len=None, token=True, seq_last=True): \n",
    "        method = method.lower()\n",
    "        assert method in ['cls', 'max', 'mean', 'max-mean', 'linear', 'conv1d', 'flatten']\n",
    "        if method == 'cls': assert token, 'you can only choose method=cls if a token exists'\n",
    "        self.method, self.token, self.seq_last = method, token, seq_last\n",
    "        if method in ('linear', 'conv1d'):\n",
    "            self.linear = nn.Linear(seq_len - token, 1)\n",
    "\n",
    "    def forward(self, x): \n",
    "        if self.method == 'cls':\n",
    "            return x[..., 0] if self.seq_last else x[:, 0]\n",
    "        if self.token:\n",
    "            # Drop the token before pooling the remaining positions\n",
    "            x = x[..., 1:] if self.seq_last else x[:, 1:] \n",
    "        seq_dim = -1 if self.seq_last else 1\n",
    "        if self.method == 'max':\n",
    "            return torch.max(x, seq_dim)[0]\n",
    "        if self.method == 'mean':\n",
    "            return torch.mean(x, seq_dim)\n",
    "        if self.method == 'max-mean':\n",
    "            return torch.cat([torch.max(x, seq_dim)[0], torch.mean(x, seq_dim)], 1)\n",
    "        if self.method == 'flatten':\n",
    "            return x.flatten(1)\n",
    "        if self.method in ('linear', 'conv1d'):\n",
    "            # The linear layer acts on the last axis, so the sequence axis is moved there first\n",
    "            if not self.seq_last: x = x.transpose(1,2)\n",
    "            return self.linear(x)[...,0]\n",
    "    \n",
    "    def __repr__(self): return f\"{self.__class__.__name__}(method={self.method}, token={self.token}, seq_last={self.seq_last})\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `t` has shape (2, 3, 4); with seq_last=True the last axis is pooled\n",
    "t = torch.arange(24).reshape(2, 3, 4).float()\n",
    "# token=True: position 0 of the last axis is returned by 'cls' and excluded by the other methods\n",
    "test_eq(PoolingLayer('cls', token=True, seq_last=True)(t), tensor([[ 0.,  4.,  8.], [12., 16., 20.]]))\n",
    "test_eq(PoolingLayer('max', token=True, seq_last=True)(t), tensor([[ 3.,  7., 11.], [15., 19., 23.]]))\n",
    "test_close(PoolingLayer('mean', token=True, seq_last=True)(t), tensor([[ 2.,  6., 10.], [14., 18., 22.]]))\n",
    "test_close(PoolingLayer('max-mean', token=True, seq_last=True)(t), tensor([[ 3.,  7., 11.,  2.,  6., 10.],\n",
    "                                                                           [15., 19., 23., 14., 18., 22.]]))\n",
    "test_close(PoolingLayer('flatten', token=True, seq_last=True)(t), tensor([[ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],\n",
    "                                                                          [13., 14., 15., 17., 18., 19., 21., 22., 23.]]))\n",
    "test_eq(PoolingLayer('linear', seq_len=4, token=True, seq_last=True)(t).shape, (2, 3))\n",
    "# token=False: every position takes part in the pooling\n",
    "test_eq(PoolingLayer('max', token=False, seq_last=True)(t), tensor([[ 3.,  7., 11.], [15., 19., 23.]]))\n",
    "test_close(PoolingLayer('mean', token=False, seq_last=True)(t), tensor([[ 1.5000,  5.5000,  9.5000],\n",
    "                                                                        [13.5000, 17.5000, 21.5000]]))\n",
    "test_close(PoolingLayer('max-mean', token=False, seq_last=True)(t), tensor([[ 3.,  7., 11.,  1.5000,  5.5000,  9.5000],\n",
    "                                                                            [15., 19., 23., 13.5000, 17.5000, 21.5000]]))\n",
    "test_close(PoolingLayer('flatten', token=False, seq_last=True)(t), tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.],\n",
    "                                                                           [12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.]]))\n",
    "test_eq(PoolingLayer('linear', seq_len=4, token=False, seq_last=True)(t).shape, (2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same values as in the seq_last=True tests, with axes 1 and 2 swapped: `t` has shape (2, 4, 3) and axis 1 is pooled\n",
    "t = torch.arange(24).reshape(2, 3, 4).swapaxes(1,2).float()\n",
    "# token=True: position 0 of axis 1 is returned by 'cls' and excluded by the other methods\n",
    "test_eq(PoolingLayer('cls', token=True, seq_last=False)(t), tensor([[ 0.,  4.,  8.], [12., 16., 20.]]))\n",
    "test_eq(PoolingLayer('max', token=True, seq_last=False)(t), tensor([[ 3.,  7., 11.], [15., 19., 23.]]))\n",
    "test_close(PoolingLayer('mean', token=True, seq_last=False)(t), tensor([[ 2.,  6., 10.], [14., 18., 22.]]))\n",
    "test_close(PoolingLayer('max-mean', token=True, seq_last=False)(t), tensor([[ 3.,  7., 11.,  2.,  6., 10.],\n",
    "                                                                           [15., 19., 23., 14., 18., 22.]]))\n",
    "test_close(PoolingLayer('flatten', token=True, seq_last=False)(t), tensor([[ 1.,  5.,  9.,  2.,  6., 10.,  3.,  7., 11.],\n",
    "                                                                           [13., 17., 21., 14., 18., 22., 15., 19., 23.]]))\n",
    "# `t` is rebuilt here with the same expression as at the top of the cell\n",
    "t = torch.arange(24).reshape(2, 3, 4).swapaxes(1,2).float()\n",
    "# token=False: every position takes part in the pooling\n",
    "test_eq(PoolingLayer('conv1d', seq_len=4, token=False, seq_last=False)(t).shape, (2, 3))\n",
    "test_eq(PoolingLayer('max', token=False, seq_last=False)(t), tensor([[ 3.,  7., 11.], [15., 19., 23.]]))\n",
    "test_close(PoolingLayer('mean', token=False, seq_last=False)(t), tensor([[ 1.5000,  5.5000,  9.5000],\n",
    "                                                                        [13.5000, 17.5000, 21.5000]]))\n",
    "test_close(PoolingLayer('max-mean', token=False, seq_last=False)(t), tensor([[ 3.,  7., 11.,  1.5000,  5.5000,  9.5000],\n",
    "                                                                            [15., 19., 23., 13.5000, 17.5000, 21.5000]]))\n",
    "test_close(PoolingLayer('flatten', token=False, seq_last=False)(t), tensor([[ 0.,  4.,  8.,  1.,  5.,  9.,  2.,  6., 10.,  3.,  7., 11.],\n",
    "                                                                            [12., 16., 20., 13., 17., 21., 14., 18., 22., 15., 19., 23.]]))\n",
    "test_eq(PoolingLayer('conv1d', seq_len=4, token=False, seq_last=False)(t).shape, (2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class GEGLU(Module):\n",
    "    \"Splits the last axis in 2 chunks and multiplies the first chunk by the GELU of the second\"\n",
    "    def forward(self, x):\n",
    "        values, gates = x.chunk(2, dim=-1)\n",
    "        activated_gates = F.gelu(gates)\n",
    "        return values * activated_gates\n",
    "\n",
    "class ReGLU(Module):\n",
    "    \"Splits the last axis in 2 chunks and multiplies the first chunk by the ReLU of the second\"\n",
    "    def forward(self, x):\n",
    "        values, gates = x.chunk(2, dim=-1)\n",
    "        activated_gates = F.relu(gates)\n",
    "        return values * activated_gates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# Activation classes that `get_act_fn` can look up by their (case-insensitive) class name\n",
    "pytorch_acts = [nn.ELU, nn.LeakyReLU, nn.PReLU, nn.ReLU, nn.ReLU6, nn.SELU, nn.CELU, nn.GELU, nn.Sigmoid, Mish, nn.Softplus,\n",
    "nn.Tanh, nn.Softmax, GEGLU, ReGLU, SmeLU]\n",
    "# Lower-cased class names, index-aligned with `pytorch_acts`\n",
    "pytorch_act_names = [a.__name__.lower() for a in pytorch_acts]\n",
    "\n",
    "def get_act_fn(act, **act_kwargs):\n",
    "    \"\"\"Returns an activation layer built from `act`\n",
    "\n",
    "    `act` may be None (None is returned), an `nn.Module` instance (returned as is), a callable such as an\n",
    "    activation class (called with `act_kwargs`), or the case-insensitive name of a class in `pytorch_acts`.\n",
    "    Raises a ValueError if the name is not one of `pytorch_act_names`.\n",
    "    \"\"\"\n",
    "    if act is None: return\n",
    "    elif isinstance(act, nn.Module): return act\n",
    "    elif callable(act): return act(**act_kwargs)\n",
    "    name = act.lower()\n",
    "    if name not in pytorch_act_names:\n",
    "        # Same exception type `list.index` raised before, but with an actionable message\n",
    "        raise ValueError(f'unknown activation {act!r}. Available options: {pytorch_act_names}')\n",
    "    return pytorch_acts[pytorch_act_names.index(name)](**act_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `get_act_fn` accepts a class, an instance or a case-insensitive name; kwargs go to the class constructor\n",
    "test_eq(get_act_fn(nn.ReLU).__repr__(), \"ReLU()\")\n",
    "test_eq(get_act_fn(nn.ReLU()).__repr__(), \"ReLU()\")\n",
    "test_eq(get_act_fn(nn.LeakyReLU, negative_slope=0.05).__repr__(), \"LeakyReLU(negative_slope=0.05)\")\n",
    "test_eq(get_act_fn('reglu').__repr__(), \"ReGLU()\")\n",
    "test_eq(get_act_fn('leakyrelu', negative_slope=0.05).__repr__(), \"LeakyReLU(negative_slope=0.05)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# NOTE(review): a second `RevIN` class is defined (and exported) in the next cell, so this definition is\n",
    "# shadowed by it. Confirm whether this version is still needed.\n",
    "class RevIN(nn.Module):\n",
    "    \"\"\" Reversible Instance Normalization layer adapted from\n",
    "\n",
    "        Kim, T., Kim, J., Tae, Y., Park, C., Choi, J. H., & Choo, J. (2021, September). \n",
    "        Reversible instance normalization for accurate time-series forecasting against distribution shift. \n",
    "        In International Conference on Learning Representations.\n",
    "        Original code: https://github.com/ts-kim/RevIN\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "         c_in:int,    # #features (aka variables or channels)\n",
    "         affine:bool=True,  # flag to indicate if RevIN has learnable weight and bias\n",
    "         subtract_last:bool=False,  # if True, the last time step (instead of the mean) is subtracted\n",
    "         dim:int=2,   # stored but not used: the statistics are always computed along the last dimension\n",
    "         eps:float=1e-5  # epsilon - parameter added for numerical stability\n",
    "         ):\n",
    "        super().__init__()\n",
    "        self.c_in, self.affine, self.subtract_last, self.dim, self.eps = c_in, affine, subtract_last, dim, eps\n",
    "        # weight and bias only exist when affine=True\n",
    "        if self.affine:\n",
    "            self.weight = nn.Parameter(torch.ones(1, c_in, 1))\n",
    "            self.bias = nn.Parameter(torch.zeros(1, c_in, 1))\n",
    "    \n",
    "    def forward(self, x:Tensor, mode:Tensor):\n",
    "        \"\"\"Args:\n",
    "\n",
    "            x: rank 3 tensor with shape [batch size x c_in x sequence length]\n",
    "            mode: torch.tensor(True) to normalize data and torch.tensor(False) to reverse normalization\n",
    "        \"\"\"\n",
    "        \n",
    "        # Normalize\n",
    "        if mode: return self.normalize(x)\n",
    "        \n",
    "        # Denormalize\n",
    "        else: return self.denormalize(x)\n",
    "           \n",
    "    def normalize(self, x):\n",
    "        # The statistics are stored on the module so that `denormalize` can reuse them later\n",
    "        if self.subtract_last:\n",
    "            self.sub = x[..., -1].unsqueeze(-1).detach()\n",
    "        else:\n",
    "            self.sub = torch.mean(x, dim=-1, keepdim=True).detach()\n",
    "        self.std = torch.std(x, dim=-1, keepdim=True, unbiased=False).detach() + self.eps\n",
    "        if self.affine:\n",
    "            x = x.sub(self.sub)\n",
    "            x = x.div(self.std)\n",
    "            x = x.mul(self.weight)\n",
    "            x = x.add(self.bias)\n",
    "            return x\n",
    "        else:\n",
    "            x = x.sub(self.sub)\n",
    "            x = x.div(self.std)\n",
    "            return x\n",
    "        \n",
    "    def denormalize(self, x):\n",
    "        # Reads `self.sub` and `self.std`, which are only set by `normalize`\n",
    "        if self.affine:\n",
    "            x = x.sub(self.bias)\n",
    "            x = x.div(self.weight)\n",
    "            x = x.mul(self.std)\n",
    "            x = x.add(self.sub)\n",
    "            return x\n",
    "        else:\n",
    "            x = x.mul(self.std)\n",
    "            x = x.add(self.sub)\n",
    "            return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class RevIN(nn.Module):\n",
    "    \"\"\" Reversible Instance Normalization layer adapted from\n",
    "\n",
    "        Kim, T., Kim, J., Tae, Y., Park, C., Choi, J. H., & Choo, J. (2021, September). \n",
    "        Reversible instance normalization for accurate time-series forecasting against distribution shift. \n",
    "        In International Conference on Learning Representations.\n",
    "        Original code: https://github.com/ts-kim/RevIN\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "         c_in:int,    # #features (aka variables or channels)\n",
    "         affine:bool=True,  # flag to indicate if RevIN applies the learnable weight and bias\n",
    "         subtract_last:bool=False,  # if True, the last time step (instead of the mean) is subtracted\n",
    "         dim:int=2,   # stored but not used: the statistics are always computed along the last dimension\n",
    "         eps:float=1e-5  # epsilon - parameter added for numerical stability\n",
    "         ):\n",
    "        super().__init__()\n",
    "        self.c_in = c_in\n",
    "        self.affine = affine\n",
    "        self.subtract_last = subtract_last\n",
    "        self.dim = dim\n",
    "        self.eps = eps\n",
    "        # weight and bias are always created, even when affine=False (they are then left unused)\n",
    "        self.weight = nn.Parameter(torch.ones(1, c_in, 1))\n",
    "        self.bias = nn.Parameter(torch.zeros(1, c_in, 1))\n",
    "        # Tensor attributes created up front; `sub` and `std` are overwritten on every normalization pass\n",
    "        self.sub = torch.zeros(1)\n",
    "        self.std = torch.ones(1)\n",
    "        self.mul = torch.ones(1)\n",
    "        self.add = torch.zeros(1)\n",
    "    \n",
    "    def forward(self, x:Tensor, mode:Tensor):\n",
    "        \"\"\"Args:\n",
    "\n",
    "            x: rank 3 tensor with shape [batch size x c_in x sequence length]\n",
    "            mode: torch.tensor(True) to normalize data and torch.tensor(False) to reverse normalization\n",
    "        \"\"\"\n",
    "        \n",
    "        if mode: \n",
    "            # Normalize: keep the statistics of this input so that the operation can be reversed later\n",
    "            if self.subtract_last:\n",
    "                self.sub = x[..., -1].unsqueeze(-1).detach()\n",
    "            else:\n",
    "                self.sub = torch.mean(x, dim=-1, keepdim=True).detach()\n",
    "            self.std = torch.std(x, dim=-1, keepdim=True, unbiased=False).detach() + self.eps\n",
    "            x = x.sub(self.sub).div(self.std)\n",
    "            if self.affine:\n",
    "                x = x.mul(self.weight).add(self.bias)\n",
    "            return x\n",
    "        else: \n",
    "            # Denormalize: undo the affine step (if any) and then the standardization\n",
    "            if self.affine:\n",
    "                x = x.sub(self.bias).div(self.weight)\n",
    "            return x.mul(self.std).add(self.sub)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5 cumulative-sum series whose increments are scaled by .01, .1, 1, 10 and 100\n",
    "t = ((torch.rand(16, 5, 100) - .25) * torch.Tensor([.01, .1, 1, 10, 100]).reshape(1, -1, 1)).cumsum(-1)\n",
    "t_clone = t.clone()\n",
    "l = RevIN(5)\n",
    "# Normalizing and then denormalizing must give back the original data (within eps=1e-3)\n",
    "t_norm = l(t, torch.tensor(True))\n",
    "t_denorm = l(t_norm, torch.tensor(False))\n",
    "test_close(t_clone, t_denorm, eps=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scripting ok\n"
     ]
    }
   ],
   "source": [
    "# Checks that RevIN can be scripted, saved and reloaded, and that the scripted module matches the eager one\n",
    "model = RevIN(5, affine=True)\n",
    "file_path = \"test_scripted_model.pt\"\n",
    "try:\n",
    "    scripted_model = torch.jit.script(model)\n",
    "    torch.jit.save(scripted_model, file_path)\n",
    "    scripted_model = torch.jit.load(file_path)\n",
    "\n",
    "    inp = ((torch.rand(16, 5, 100) - .25) * torch.Tensor([.01, .1, 1, 10, 100]).reshape(1, -1, 1)).cumsum(-1)\n",
    "    normed_output = model(inp, torch.tensor(True))\n",
    "    denormed_output = model(normed_output, torch.tensor(False))\n",
    "    scripted_normed_output = scripted_model(inp, torch.tensor(True))\n",
    "    scripted_denormed_output = scripted_model(scripted_normed_output, torch.tensor(False))\n",
    "    test_close(normed_output, scripted_normed_output)\n",
    "    test_close(denormed_output, scripted_denormed_output)\n",
    "    del scripted_model\n",
    "    gc.collect()\n",
    "    print('scripting ok')\n",
    "except Exception as e:\n",
    "    print(f'scripting failed: {e}')\n",
    "finally:\n",
    "    # Remove the temporary file even when scripting or one of the comparisons fails\n",
    "    if os.path.exists(file_path): os.remove(file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def create_pool_head(n_in, c_out, seq_len=None, concat_pool=False, fc_dropout=0., bn=False, y_range=None, **kwargs):\n",
    "    \"\"\"Head made of a global pooling layer (GACP1d if `concat_pool` else GAP1d) and a LinBnDrop layer\n",
    "\n",
    "    With `concat_pool` the linear layer receives `n_in * 2` features. A SigmoidRange layer is appended when\n",
    "    `y_range` is passed. `seq_len` is accepted but not used, and extra `kwargs` are only reported.\n",
    "    \"\"\"\n",
    "    if kwargs: print(f'{kwargs}  not being used')\n",
    "    if concat_pool:\n",
    "        pool, n_features = GACP1d(1), n_in * 2\n",
    "    else:\n",
    "        pool, n_features = GAP1d(1), n_in\n",
    "    layers = [pool, LinBnDrop(n_features, c_out, bn=bn, p=fc_dropout)]\n",
    "    if y_range: layers.append(SigmoidRange(*y_range))\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "pool_head = create_pool_head\n",
    "average_pool_head = partial(pool_head, concat_pool=False)\n",
    "setattr(average_pool_head, \"__name__\", \"average_pool_head\")\n",
    "concat_pool_head = partial(pool_head, concat_pool=True)\n",
    "setattr(concat_pool_head, \"__name__\", \"concat_pool_head\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): GACP1d(\n",
       "    (gacp): AdaptiveConcatPool1d(\n",
       "      (ap): AdaptiveAvgPool1d(output_size=1)\n",
       "      (mp): AdaptiveMaxPool1d(output_size=1)\n",
       "    )\n",
       "    (flatten): Reshape(bs)\n",
       "  )\n",
       "  (1): LinBnDrop(\n",
       "    (0): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (1): Dropout(p=0.5, inplace=False)\n",
       "    (2): Linear(in_features=24, out_features=2, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(create_pool_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "test_eq(create_pool_head(nf, c_out, seq_len, concat_pool=True, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "create_pool_head(nf, c_out, seq_len, concat_pool=True, bn=True, fc_dropout=.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def max_pool_head(n_in, c_out, seq_len, fc_dropout=0., bn=False, y_range=None, **kwargs):\n",
    "    if kwargs: print(f'{kwargs}  not being used')\n",
    "    layers = [nn.MaxPool1d(seq_len, **kwargs), Reshape()]\n",
    "    layers += [LinBnDrop(n_in, c_out, bn=bn, p=fc_dropout)]\n",
    "    if y_range: layers += [SigmoidRange(*y_range)]\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(max_pool_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def create_pool_plus_head(*args, lin_ftrs=None, fc_dropout=0., concat_pool=True, bn_final=False, lin_first=False, y_range=None):\n",
    "    nf = args[0]\n",
    "    c_out = args[1]\n",
    "    if concat_pool: nf = nf * 2\n",
    "    lin_ftrs = [nf, 512, c_out] if lin_ftrs is None else [nf] + lin_ftrs + [c_out]\n",
    "    ps = L(fc_dropout)\n",
    "    if len(ps) == 1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps\n",
    "    actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]\n",
    "    pool = AdaptiveConcatPool1d() if concat_pool else nn.AdaptiveAvgPool1d(1)\n",
    "    layers = [pool, Reshape()]\n",
    "    if lin_first: layers.append(nn.Dropout(ps.pop(0)))\n",
    "    for ni,no,p,actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):\n",
    "        layers += LinBnDrop(ni, no, bn=True, p=p, act=actn, lin_first=lin_first)\n",
    "    if lin_first: layers.append(nn.Linear(lin_ftrs[-2], c_out))\n",
    "    if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))\n",
    "    if y_range is not None: layers.append(SigmoidRange(*y_range))\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "pool_plus_head = create_pool_plus_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): AdaptiveConcatPool1d(\n",
       "    (ap): AdaptiveAvgPool1d(output_size=1)\n",
       "    (mp): AdaptiveMaxPool1d(output_size=1)\n",
       "  )\n",
       "  (1): Reshape(bs)\n",
       "  (2): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (3): Dropout(p=0.25, inplace=False)\n",
       "  (4): Linear(in_features=24, out_features=512, bias=False)\n",
       "  (5): ReLU(inplace=True)\n",
       "  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (7): Dropout(p=0.5, inplace=False)\n",
       "  (8): Linear(in_features=512, out_features=2, bias=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(create_pool_plus_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "test_eq(create_pool_plus_head(nf, c_out, concat_pool=True, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "create_pool_plus_head(nf, c_out, seq_len, fc_dropout=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def create_conv_head(*args, adaptive_size=None, y_range=None):\n",
    "    nf = args[0]\n",
    "    c_out = args[1]\n",
    "    layers = [nn.AdaptiveAvgPool1d(adaptive_size)] if adaptive_size is not None else []\n",
    "    for i in range(2):\n",
    "        if nf > 1: \n",
    "            layers += [ConvBlock(nf, nf // 2, 1)] \n",
    "            nf = nf//2\n",
    "        else: break\n",
    "    layers += [ConvBlock(nf, c_out, 1), GAP1d(1)]\n",
    "    if y_range: layers += [SigmoidRange(*y_range)]\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "conv_head = create_conv_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): ConvBlock(\n",
       "    (0): Conv1d(12, 6, kernel_size=(1,), stride=(1,), bias=False)\n",
       "    (1): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU()\n",
       "  )\n",
       "  (1): ConvBlock(\n",
       "    (0): Conv1d(6, 3, kernel_size=(1,), stride=(1,), bias=False)\n",
       "    (1): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU()\n",
       "  )\n",
       "  (2): ConvBlock(\n",
       "    (0): Conv1d(3, 2, kernel_size=(1,), stride=(1,), bias=False)\n",
       "    (1): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU()\n",
       "  )\n",
       "  (3): GAP1d(\n",
       "    (gap): AdaptiveAvgPool1d(output_size=1)\n",
       "    (flatten): Reshape(bs)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(create_conv_head(nf, c_out, seq_len)(t).shape, (bs, c_out))\n",
    "test_eq(create_conv_head(nf, c_out, adaptive_size=50)(t).shape, (bs, c_out))\n",
    "create_conv_head(nf, c_out, 50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def create_mlp_head(nf, c_out, seq_len=None, flatten=True, fc_dropout=0., bn=False, lin_first=False, y_range=None):\n",
    "    if flatten: nf *= seq_len\n",
    "    layers = [Reshape()] if flatten else []\n",
    "    layers += [LinBnDrop(nf, c_out, bn=bn, p=fc_dropout, lin_first=lin_first)]\n",
    "    if y_range: layers += [SigmoidRange(*y_range)]\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "mlp_head = create_mlp_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Reshape(bs)\n",
       "  (1): LinBnDrop(\n",
       "    (0): BatchNorm1d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (1): Dropout(p=0.5, inplace=False)\n",
       "    (2): Linear(in_features=240, out_features=2, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(create_mlp_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "create_mlp_head(nf, c_out, seq_len, bn=True, fc_dropout=.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def create_fc_head(nf, c_out, seq_len=None, flatten=True, lin_ftrs=None, y_range=None, fc_dropout=0., bn=False, bn_final=False, act=nn.ReLU(inplace=True)):\n",
    "    if flatten: nf *= seq_len\n",
    "    layers = [Reshape()] if flatten else []\n",
    "    lin_ftrs = [nf, 512, c_out] if lin_ftrs is None else [nf] + lin_ftrs + [c_out]\n",
    "    if not is_listy(fc_dropout): fc_dropout = [fc_dropout]*(len(lin_ftrs) - 1)\n",
    "    actns = [act for _ in range(len(lin_ftrs) - 2)] + [None]\n",
    "    layers += [LinBnDrop(lin_ftrs[i], lin_ftrs[i+1], bn=bn and (i!=len(actns)-1 or bn_final), p=p, act=a) for i,(p,a) in enumerate(zip(fc_dropout+[0.], actns))]\n",
    "    if y_range is not None: layers.append(SigmoidRange(*y_range))\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "fc_head = create_fc_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Reshape(bs)\n",
       "  (1): LinBnDrop(\n",
       "    (0): BatchNorm1d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (1): Dropout(p=0.5, inplace=False)\n",
       "    (2): Linear(in_features=240, out_features=2, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(create_fc_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "create_mlp_head(nf, c_out, seq_len, bn=True, fc_dropout=.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def create_rnn_head(*args, fc_dropout=0., bn=False, y_range=None):\n",
    "    nf = args[0]\n",
    "    c_out = args[1]\n",
    "    layers = [LastStep()]\n",
    "    layers += [LinBnDrop(nf, c_out, bn=bn, p=fc_dropout)]\n",
    "    if y_range: layers += [SigmoidRange(*y_range)]\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "rnn_head = create_rnn_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): LastStep()\n",
       "  (1): LinBnDrop(\n",
       "    (0): BatchNorm1d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (1): Dropout(p=0.5, inplace=False)\n",
       "    (2): Linear(in_features=12, out_features=2, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "c_out = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "test_eq(create_rnn_head(nf, c_out, seq_len, fc_dropout=0.5)(t).shape, (bs, c_out))\n",
    "create_rnn_head(nf, c_out, seq_len, bn=True, fc_dropout=.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def imputation_head(c_in, c_out, seq_len=None, ks=1, y_range=None, fc_dropout=0.):\n",
    "    layers = [nn.Dropout(fc_dropout), nn.Conv1d(c_in, c_out, ks)]\n",
    "    if y_range is not None: \n",
    "        y_range = (tensor(y_range[0]), tensor(y_range[1]))\n",
    "        layers += [SigmoidRange(*y_range)]\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Dropout(p=0.0, inplace=False)\n",
       "  (1): Conv1d(12, 2, kernel_size=(1,), stride=(1,))\n",
       "  (2): fastai.layers.SigmoidRange(low=tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.2000, 0.2000, 0.2000, 0.2000, 0.3000,\n",
       "          0.3000, 0.3000, 0.3000]), high=tensor([0.6000, 0.6000, 0.6000, 0.6000, 0.7000, 0.7000, 0.7000, 0.7000, 0.8000,\n",
       "          0.8000, 0.8000, 0.8000]))\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 12\n",
    "ni = 2\n",
    "seq_len = 20\n",
    "t = torch.rand(bs, nf, seq_len)\n",
    "head = imputation_head(nf, ni, seq_len=None, ks=1, y_range=None, fc_dropout=0.)\n",
    "test_eq(head(t).shape, (bs, ni, seq_len))\n",
    "head = imputation_head(nf, ni, seq_len=None, ks=1, y_range=(.3,.7), fc_dropout=0.)\n",
    "test_ge(head(t).min(), .3)\n",
    "test_le(head(t).max(), .7)\n",
    "y_range = (tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.2000, 0.2000, 0.2000, 0.2000, 0.3000,\n",
    "                   0.3000, 0.3000, 0.3000]),\n",
    "           tensor([0.6000, 0.6000, 0.6000, 0.6000, 0.7000, 0.7000, 0.7000, 0.7000, 0.8000,\n",
    "                   0.8000, 0.8000, 0.8000]))\n",
    "test_ge(head(t).min(), .1)\n",
    "test_le(head(t).max(), .9)\n",
    "head = imputation_head(nf, ni, seq_len=None, ks=1, y_range=y_range, fc_dropout=0.)\n",
    "head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class create_conv_lin_nd_head(nn.Sequential):\n",
    "    \"Module to create a nd output head\"\n",
    "\n",
    "    def __init__(self, n_in, n_out, seq_len, d, conv_first=True, conv_bn=False, lin_bn=False, fc_dropout=0., **kwargs):\n",
    "\n",
    "        assert d, \"you cannot use an nd head when d is None or 0\"\n",
    "        if is_listy(d):\n",
    "            fd = 1\n",
    "            shape = []\n",
    "            for _d in d:\n",
    "                fd *= _d\n",
    "                shape.append(_d)\n",
    "            if n_out > 1: shape.append(n_out)\n",
    "        else: \n",
    "            fd = d\n",
    "            shape = [d, n_out] if n_out > 1 else [d]\n",
    "        \n",
    "        conv = [BatchNorm(n_in, ndim=1)] if conv_bn else []\n",
    "        conv.append(Conv1d(n_in, n_out, 1, padding=0, bias=not conv_bn, **kwargs))\n",
    "        l = [Transpose(-1, -2), BatchNorm(seq_len, ndim=1), Transpose(-1, -2)] if lin_bn else []\n",
    "        if fc_dropout != 0: l.append(nn.Dropout(fc_dropout))\n",
    "        lin = [nn.Linear(seq_len, fd, bias=not lin_bn)]\n",
    "        lin_layers = l+lin\n",
    "        layers = conv + lin_layers if conv_first else lin_layers + conv\n",
    "        layers += [Transpose(-1,-2)]\n",
    "        layers += [Reshape(*shape)]\n",
    "\n",
    "        super().__init__(*layers)\n",
    "        \n",
    "conv_lin_nd_head = create_conv_lin_nd_head\n",
    "conv_lin_3d_head = create_conv_lin_nd_head # included for compatibility\n",
    "create_conv_lin_3d_head = create_conv_lin_nd_head # included for compatibility"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorBase(1.7074, grad_fn=<AliasBackward0>),\n",
       " create_conv_lin_nd_head(\n",
       "   (0): Conv1d(32, 5, kernel_size=(1,), stride=(1,))\n",
       "   (1): Dropout(p=0.5, inplace=False)\n",
       "   (2): Linear(in_features=10, out_features=2, bias=True)\n",
       "   (3): Transpose(-1, -2)\n",
       "   (4): Reshape(bs, 2, 5)\n",
       " ))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 32\n",
    "c = 5\n",
    "seq_len = 10\n",
    "d = 2\n",
    "targ = torch.randint(0, c, (bs,d))\n",
    "t = torch.randn(bs, nf, seq_len)\n",
    "head = conv_lin_nd_head(nf, c, seq_len, d, conv_first=True, fc_dropout=.5)\n",
    "inp = head(t)\n",
    "test_eq(inp.shape, (bs, d, c))\n",
    "loss = CrossEntropyLossFlat()(inp, targ)\n",
    "loss, head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorBase(1.6561, grad_fn=<AliasBackward0>),\n",
       " create_conv_lin_nd_head(\n",
       "   (0): Dropout(p=0.5, inplace=False)\n",
       "   (1): Linear(in_features=10, out_features=16, bias=True)\n",
       "   (2): Conv1d(32, 5, kernel_size=(1,), stride=(1,))\n",
       "   (3): Transpose(-1, -2)\n",
       "   (4): Reshape(bs, 2, 8, 5)\n",
       " ))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 32\n",
    "c = 5\n",
    "seq_len = 10\n",
    "d = [2, 8]\n",
    "targ = torch.randint(0, c, [bs]+d)\n",
    "t = torch.randn(bs, nf, seq_len)\n",
    "head = conv_lin_nd_head(nf, c, seq_len, d, conv_first=False, fc_dropout=.5)\n",
    "inp = head(t)\n",
    "test_eq(inp.shape, [bs]+d+[c])\n",
    "loss = CrossEntropyLossFlat()(inp, targ)\n",
    "loss, head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorBase(0.6017, grad_fn=<AliasBackward0>),\n",
       " create_conv_lin_nd_head(\n",
       "   (0): Dropout(p=0.5, inplace=False)\n",
       "   (1): Linear(in_features=10, out_features=2, bias=True)\n",
       "   (2): Conv1d(32, 1, kernel_size=(1,), stride=(1,))\n",
       "   (3): Transpose(-1, -2)\n",
       "   (4): Reshape(bs, 2)\n",
       " ))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 32\n",
    "c = 1\n",
    "seq_len = 10\n",
    "d = 2\n",
    "targ = torch.rand(bs, d)\n",
    "t = torch.randn(bs, nf, seq_len)\n",
    "head = conv_lin_nd_head(nf, c, seq_len, d, conv_first=False, fc_dropout=.5)\n",
    "inp = head(t)\n",
    "test_eq(inp.shape, (bs, d))\n",
    "loss = L1LossFlat()(inp, targ)\n",
    "loss, head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorBase(0.5439, grad_fn=<AliasBackward0>),\n",
       " create_conv_lin_nd_head(\n",
       "   (0): Dropout(p=0.5, inplace=False)\n",
       "   (1): Linear(in_features=10, out_features=6, bias=True)\n",
       "   (2): Conv1d(32, 1, kernel_size=(1,), stride=(1,))\n",
       "   (3): Transpose(-1, -2)\n",
       "   (4): Reshape(bs, 2, 3)\n",
       " ))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 32\n",
    "c = 1\n",
    "seq_len = 10\n",
    "d = [2,3]\n",
    "targ = torch.rand(bs, *d)\n",
    "t = torch.randn(bs, nf, seq_len)\n",
    "head = conv_lin_nd_head(nf, c, seq_len, d, conv_first=False, fc_dropout=.5)\n",
    "inp = head(t)\n",
    "test_eq(inp.shape, [bs]+d)\n",
    "loss = L1LossFlat()(inp, targ)\n",
    "loss, head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class lin_nd_head(nn.Sequential):\n",
    "    \"Module to create a nd output head with linear layers\"\n",
    "\n",
    "    def __init__(self, n_in, n_out, seq_len=None, d=None, flatten=False, use_bn=False, fc_dropout=0.):\n",
    "\n",
    "        if seq_len is None:\n",
    "            seq_len = 1\n",
    "        if d is None:\n",
    "            fd = 1\n",
    "            shape = [n_out]\n",
    "        elif is_listy(d):\n",
    "            fd = 1\n",
    "            shape = []\n",
    "            for _d in d:\n",
    "                fd *= _d\n",
    "                shape.append(_d)\n",
    "            if n_out > 1: shape.append(n_out)\n",
    "        else: \n",
    "            fd = d\n",
    "            shape = [d, n_out] if n_out > 1 else [d]\n",
    "            \n",
    "        layers = []\n",
    "        if use_bn:\n",
    "            layers += [nn.BatchNorm1d(n_in)]\n",
    "        if fc_dropout:\n",
    "            layers += [nn.Dropout(fc_dropout)]\n",
    "        if d is None:\n",
    "            if not flatten or seq_len == 1:\n",
    "                layers += [nn.AdaptiveAvgPool1d(1), Squeeze(-1), nn.Linear(n_in, n_out)]\n",
    "                if n_out == 1:\n",
    "                    layers += [Squeeze(-1)]\n",
    "            else:\n",
    "                layers += [Reshape(), nn.Linear(n_in * seq_len, n_out * fd)]\n",
    "                if n_out * fd== 1:\n",
    "                    layers += [Squeeze(-1)]\n",
    "        else:\n",
    "            if seq_len == 1:\n",
    "                layers += [nn.AdaptiveAvgPool1d(1)]\n",
    "            if not flatten and fd == seq_len:\n",
    "                layers += [Transpose(1,2), nn.Linear(n_in, n_out)]\n",
    "            else:\n",
    "                layers += [Reshape(), nn.Linear(n_in * seq_len, n_out * fd)]\n",
    "            layers += [Reshape(*shape)]\n",
    "\n",
    "        super().__init__(*layers)\n",
    "        \n",
    "create_lin_nd_head = lin_nd_head\n",
    "lin_3d_head = lin_nd_head # included for backwards compatiblity\n",
    "create_lin_3d_head = lin_nd_head # included for backwards compatiblity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "nf = 32\n",
    "seq_len = 50\n",
    "x = torch.normal(0, 1, (bs, nf, seq_len))\n",
    "\n",
    "for use_bn in [False, True]:\n",
    "    for fc_dropout in [0, 0.2]:\n",
    "        for flatten in [False, True]:\n",
    "            for c in [1, 3]:\n",
    "                for d in [None, (50,), (50,10), (30,5), (50,2,3), (30,2,3)]:\n",
    "                    for q_len in [1, seq_len]:\n",
    "                        head = lin_nd_head(nf, c, q_len, d, flatten=flatten, use_bn=use_bn, fc_dropout=fc_dropout)\n",
    "                        test_eq(head(x).shape, (bs, ) + (d if d is not None else ()) + ((c,) if c != 1 else ()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorBase(1.8360, grad_fn=<AliasBackward0>),\n",
       " lin_nd_head(\n",
       "   (0): Dropout(p=0.5, inplace=False)\n",
       "   (1): Reshape(bs)\n",
       "   (2): Linear(in_features=320, out_features=10, bias=True)\n",
       "   (3): Reshape(bs, 2, 5)\n",
       " ))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 32\n",
    "c = 5\n",
    "seq_len = 10\n",
    "d = 2\n",
    "targ = torch.randint(0, c, (bs,d))\n",
    "t = torch.randn(bs, nf, seq_len)\n",
    "head = lin_nd_head(nf, c, seq_len, d, fc_dropout=.5)\n",
    "inp = head(t)\n",
    "test_eq(inp.shape, (bs, d, c))\n",
    "loss = CrossEntropyLossFlat()(inp, targ)\n",
    "loss, head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorBase(1.7557, grad_fn=<AliasBackward0>),\n",
       " lin_nd_head(\n",
       "   (0): Dropout(p=0.5, inplace=False)\n",
       "   (1): Reshape(bs)\n",
       "   (2): Linear(in_features=320, out_features=80, bias=True)\n",
       "   (3): Reshape(bs, 2, 8, 5)\n",
       " ))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 32\n",
    "c = 5\n",
    "seq_len = 10\n",
    "d = [2, 8]\n",
    "targ = torch.randint(0, c, [bs]+d)\n",
    "t = torch.randn(bs, nf, seq_len)\n",
    "head = lin_nd_head(nf, c, seq_len, d, fc_dropout=.5)\n",
    "inp = head(t)\n",
    "test_eq(inp.shape, [bs]+d+[c])\n",
    "loss = CrossEntropyLossFlat()(inp, targ)\n",
    "loss, head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorBase(0.5978, grad_fn=<AliasBackward0>),\n",
       " lin_nd_head(\n",
       "   (0): Dropout(p=0.5, inplace=False)\n",
       "   (1): Reshape(bs)\n",
       "   (2): Linear(in_features=320, out_features=2, bias=True)\n",
       "   (3): Reshape(bs, 2)\n",
       " ))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 32\n",
    "c = 1\n",
    "seq_len = 10\n",
    "d = 2\n",
    "targ = torch.rand(bs, d)\n",
    "t = torch.randn(bs, nf, seq_len)\n",
    "head = lin_nd_head(nf, c, seq_len, d, fc_dropout=.5)\n",
    "inp = head(t)\n",
    "test_eq(inp.shape, (bs, d))\n",
    "loss = L1LossFlat()(inp, targ)\n",
    "loss, head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorBase(0.8286, grad_fn=<AliasBackward0>),\n",
       " lin_nd_head(\n",
       "   (0): Dropout(p=0.5, inplace=False)\n",
       "   (1): Reshape(bs)\n",
       "   (2): Linear(in_features=320, out_features=6, bias=True)\n",
       "   (3): Reshape(bs, 2, 3)\n",
       " ))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 16\n",
    "nf = 32\n",
    "c = 1\n",
    "seq_len = 10\n",
    "d = [2,3]\n",
    "targ = torch.rand(bs, *d)\n",
    "t = torch.randn(bs, nf, seq_len)\n",
    "head = lin_nd_head(nf, c, seq_len, d, fc_dropout=.5)\n",
    "inp = head(t)\n",
    "test_eq(inp.shape, [bs]+d)\n",
    "loss = L1LossFlat()(inp, targ)\n",
    "loss, head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class rocket_nd_head(nn.Sequential):\n",
    "    \"Module to create a nd output head with linear layers for the rocket family of models\"\n",
    "\n",
    "    def __init__(self, n_in, n_out, seq_len=None, d=None, use_bn=False, fc_dropout=0., zero_init=True):\n",
    "\n",
    "        if d is None:\n",
    "            fd = 1\n",
    "            shape = [n_out]\n",
    "        elif is_listy(d):\n",
    "            fd = 1\n",
    "            shape = []\n",
    "            for _d in d:\n",
    "                fd *= _d\n",
    "                shape.append(_d)\n",
    "            if n_out > 1: shape.append(n_out)\n",
    "        else: \n",
    "            fd = d\n",
    "            shape = [d, n_out] if n_out > 1 else [d]\n",
    "\n",
    "        layers = [nn.Flatten()]\n",
    "        if use_bn:\n",
    "            layers += [nn.BatchNorm1d(n_in)]\n",
    "        if fc_dropout:\n",
    "            layers += [nn.Dropout(fc_dropout)]\n",
    "        linear = nn.Linear(n_in, fd * n_out)\n",
    "        if zero_init:\n",
    "            nn.init.constant_(linear.weight.data, 0)\n",
    "            nn.init.constant_(linear.bias.data, 0)\n",
    "        layers += [linear]\n",
    "        if d is None and n_out == 1:\n",
    "            layers += [Squeeze(-1)]\n",
    "        if d is not None:\n",
    "            layers += [Reshape(*shape)]\n",
    "\n",
    "        super().__init__(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "nf = 99\n",
    "seq_len = 1\n",
    "x = torch.normal(0, 1, (bs, nf, seq_len))\n",
    "\n",
    "for use_bn in [False, True]:\n",
    "    for fc_dropout in [0, 0.2]:\n",
    "        for c in [1, 3]:\n",
    "            for d in [None, (50,), (50,10), (30,5), (50,2,3), (30,2,3)]:\n",
    "                head = rocket_nd_head(nf, c, 1, d, use_bn=use_bn, fc_dropout=fc_dropout)\n",
    "                test_eq(head(x).shape, (bs, ) + (d if d is not None else ()) + ((c,) if c != 1 else ()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class xresnet1d_nd_head(nn.Sequential):\n",
    "    \"Module to create a nd output head with linear layers for the xresnet family of models\"\n",
    "\n",
    "    def __init__(self, n_in, n_out, seq_len=None, d=None, use_bn=False, fc_dropout=0., zero_init=True):\n",
    "\n",
    "        if d is None:\n",
    "            fd = 1\n",
    "            shape = [n_out]\n",
    "        elif is_listy(d):\n",
    "            fd = 1\n",
    "            shape = []\n",
    "            for _d in d:\n",
    "                fd *= _d\n",
    "                shape.append(_d)\n",
    "            if n_out > 1: shape.append(n_out)\n",
    "        else: \n",
    "            fd = d\n",
    "            shape = [d, n_out] if n_out > 1 else [d]\n",
    "\n",
    "        layers = [nn.AdaptiveAvgPool1d(1), nn.Flatten()]\n",
    "        if use_bn:\n",
    "            layers += [nn.BatchNorm1d(n_in)]\n",
    "        if fc_dropout:\n",
    "            layers += [nn.Dropout(fc_dropout)]\n",
    "        linear = nn.Linear(n_in, fd * n_out)\n",
    "        if zero_init:\n",
    "            nn.init.constant_(linear.weight.data, 0)\n",
    "            nn.init.constant_(linear.bias.data, 0)\n",
    "        layers += [linear]\n",
    "        if d is None and n_out == 1:\n",
    "            layers += [Squeeze(-1)]\n",
    "        if d is not None:\n",
    "            layers += [Reshape(*shape)]\n",
    "\n",
    "        super().__init__(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "nf = 99\n",
    "seq_len = 2\n",
    "x = torch.normal(0, 1, (bs, nf, seq_len))\n",
    "\n",
    "for use_bn in [False, True]:\n",
    "    for fc_dropout in [0, 0.2]:\n",
    "        for c in [1, 3]:\n",
    "            for d in [None, (50,), (50,10), (30,5), (50,2,3), (30,2,3)]:\n",
    "                head = xresnet1d_nd_head(nf, c, 1, d, use_bn=use_bn, fc_dropout=fc_dropout)\n",
    "                test_eq(head(x).shape, (bs, ) + (d if d is not None else ()) + ((c,) if c != 1 else ()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class create_conv_3d_head(nn.Sequential):\n",
    "    \"Head that applies a kernel-size-1 `Conv` and then swaps the last two dims (requires `d == seq_len`)\"\n",
    "    def __init__(self, n_in, n_out, seq_len, d, use_bn=False, **kwargs):\n",
    "        assert d, \"you cannot use an 3d head when d is None or 0\"\n",
    "        assert d == seq_len, 'You can only use this head when learn.dls.len == learn.dls.d'\n",
    "        # Optional batchnorm first, then the convolution and the transpose\n",
    "        modules = []\n",
    "        if use_bn:\n",
    "            modules.append(nn.BatchNorm1d(n_in))\n",
    "        modules.extend([Conv(n_in, n_out, 1, **kwargs), Transpose(-1,-2)])\n",
    "        # With a single output the trailing dim is squeezed away\n",
    "        if n_out == 1:\n",
    "            modules.append(Squeeze(-1))\n",
    "        super().__init__(*modules)\n",
    "        \n",
    "conv_3d_head = create_conv_3d_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorBase(1.7321, grad_fn=<AliasBackward0>),\n",
       " create_conv_3d_head(\n",
       "   (0): ConvBlock(\n",
       "     (0): Conv1d(32, 5, kernel_size=(1,), stride=(1,))\n",
       "   )\n",
       "   (1): Transpose(-1, -2)\n",
       " ))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# c > 1: the head output keeps a trailing class dim -> (bs, d, c)\n",
    "bs = 16\n",
    "nf = 32\n",
    "c = 5\n",
    "seq_len = 10\n",
    "d = 10\n",
    "targ = torch.randint(0, c, (bs,d))\n",
    "t = torch.randn(bs, nf, seq_len)\n",
    "head = conv_3d_head(nf, c, seq_len, d)\n",
    "inp = head(t)\n",
    "test_eq(inp.shape, (bs, d, c))\n",
    "loss = CrossEntropyLossFlat()(inp, targ)\n",
    "loss, head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorBase(0.5833, grad_fn=<AliasBackward0>),\n",
       " create_conv_3d_head(\n",
       "   (0): ConvBlock(\n",
       "     (0): Conv1d(32, 1, kernel_size=(1,), stride=(1,))\n",
       "   )\n",
       "   (1): Transpose(-1, -2)\n",
       "   (2): Squeeze(dim=-1)\n",
       " ))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# c == 1: the head appends `Squeeze(-1)`, so the output shape is (bs, d)\n",
    "bs = 16\n",
    "nf = 32\n",
    "c = 1\n",
    "seq_len = 10\n",
    "d = 10\n",
    "targ = torch.rand(bs, d)\n",
    "t = torch.randn(bs, nf, seq_len)\n",
    "head = conv_3d_head(nf, c, seq_len, d)\n",
    "inp = head(t)\n",
    "test_eq(inp.shape, (bs, d))\n",
    "loss = L1LossFlat()(inp, targ)\n",
    "loss, head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def universal_pool_head(n_in, c_out, seq_len, mult=2, pool_n_layers=2, pool_ln=True, pool_dropout=0.5, pool_act=nn.ReLU(),\n",
    "                        zero_init=True, bn=True, fc_dropout=0.):\n",
    "    \"Head made of an `AdaptiveWeightedAvgPool1d`, a `Reshape` and a `LinBnDrop` with `c_out` outputs\"\n",
    "    # NOTE: `zero_init` is accepted but is not referenced anywhere in this function body.\n",
    "    pool = AdaptiveWeightedAvgPool1d(n_in, seq_len, n_layers=pool_n_layers, mult=mult, ln=pool_ln, dropout=pool_dropout, act=pool_act)\n",
    "    reshape = Reshape()\n",
    "    linear = LinBnDrop(n_in, c_out, p=fc_dropout, bn=bn)\n",
    "    return nn.Sequential(pool, reshape, linear)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output of the universal pool head should be (bs, c_out)\n",
    "bs, c_in, seq_len = 16, 128, 50\n",
    "c_out = 14\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "uph = universal_pool_head(c_in, c_out, seq_len)\n",
    "test_eq(uph(t).shape, (bs, c_out))\n",
    "# The 4th positional argument is `mult`; 2 is also its default value\n",
    "uph = universal_pool_head(c_in, c_out, seq_len, 2)\n",
    "test_eq(uph(t).shape, (bs, c_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# Exported list of the head constructors defined in this module\n",
    "heads = [mlp_head, fc_head, average_pool_head, max_pool_head, concat_pool_head, pool_plus_head, conv_head, rnn_head, \n",
    "         conv_lin_nd_head, lin_nd_head, conv_3d_head, attentional_pool_head, universal_pool_head, gwa_pool_head]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "create_mlp_head\n",
      "create_fc_head\n",
      "average_pool_head\n",
      "max_pool_head\n",
      "concat_pool_head\n",
      "create_pool_plus_head\n",
      "create_conv_head\n",
      "create_rnn_head\n",
      "create_conv_lin_nd_head\n",
      "lin_nd_head\n",
      "create_conv_3d_head\n",
      "attentional_pool_head\n",
      "universal_pool_head\n",
      "gwa_pool_head\n"
     ]
    }
   ],
   "source": [
    "# Build every head in `heads` and check its output shape\n",
    "bs, c_in, seq_len = 16, 128, 50\n",
    "c_out = 14\n",
    "d = 5\n",
    "t = torch.rand(bs, c_in, seq_len)\n",
    "for head in heads: \n",
    "    print(head.__name__)\n",
    "    # The conv 3d head asserts d == seq_len, so seq_len is passed as `d`\n",
    "    if head.__name__ == \"create_conv_3d_head\":\n",
    "        h = head(c_in, c_out, seq_len, seq_len)\n",
    "        test_eq(h(t).shape, (bs, seq_len, c_out))\n",
    "    # Heads with 'nd' in their name take an extra `d` argument and return (bs, d, c_out)\n",
    "    elif 'nd' in head.__name__: \n",
    "        h = head(c_in, c_out, seq_len, d)\n",
    "        test_eq(h(t).shape, (bs, d, c_out))\n",
    "    # All remaining heads return (bs, c_out)\n",
    "    else: \n",
    "        h = head(c_in, c_out, seq_len)\n",
    "        test_eq(h(t).shape, (bs, c_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class SqueezeExciteBlock(Module):\n",
    "    \"Rescales each channel of `x` with a sigmoid gate computed from the `GAP1d` pooled input\"\n",
    "    def __init__(self, ni, reduction=16):\n",
    "        n_hidden = ni // reduction  # width of the bottleneck between the two linear layers\n",
    "        self.avg_pool = GAP1d(1)\n",
    "        squeeze = nn.Linear(ni, n_hidden, bias=False)\n",
    "        excite = nn.Linear(n_hidden, ni, bias=False)\n",
    "        self.fc = nn.Sequential(squeeze, nn.ReLU(), excite, nn.Sigmoid())\n",
    "\n",
    "    def forward(self, x):\n",
    "        pooled = self.avg_pool(x)\n",
    "        # One gate value per channel, broadcast over the last dim of `x`\n",
    "        gate = self.fc(pooled).unsqueeze(2)\n",
    "        return x * gate.expand_as(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The block must preserve the input shape\n",
    "bs = 2\n",
    "ni = 32\n",
    "sl = 4\n",
    "t = torch.rand(bs, ni, sl)\n",
    "test_eq(SqueezeExciteBlock(ni)(t).shape, (bs, ni, sl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class GaussianNoise(Module):\n",
    "    \"\"\"Gaussian noise regularizer.\n",
    "\n",
    "    Noise is only added in training mode and when `sigma` is neither 0 nor None; otherwise\n",
    "    the input is returned unchanged.\n",
    "\n",
    "    Args:\n",
    "        sigma (float, optional): relative standard deviation used to generate the\n",
    "            noise. Relative means that it will be multiplied by the magnitude of\n",
    "            the value you are adding the noise to. This means that sigma can be\n",
    "            the same regardless of the scale of the vector.\n",
    "        is_relative_detach (bool, optional): whether to detach the variable before\n",
    "            computing the scale of the noise. If `False` then the scale of the noise\n",
    "            won't be seen as a constant but something to optimize: this will bias the\n",
    "            network to generate vectors with smaller values.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, sigma=0.1, is_relative_detach=True):\n",
    "        self.sigma, self.is_relative_detach = sigma, is_relative_detach\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.training and self.sigma not in [0, None]:\n",
    "            scale = self.sigma * (x.detach() if self.is_relative_detach else x)\n",
    "            sampled_noise = torch.empty(x.size(), device=x.device).normal_()\n",
    "            # The noise is sampled in the default dtype. Cast it to the dtype of `scale` for floating\n",
    "            # point inputs so half / bfloat16 activations are not silently promoted to float32.\n",
    "            # For float32 and float64 inputs the result is the same as without the cast.\n",
    "            if scale.is_floating_point(): sampled_noise = sampled_noise.to(scale.dtype)\n",
    "            x = x + sampled_noise * scale\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freshly built modules are in training mode, so noise is added: values change, shape does not.\n",
    "# Checked for 3d, 2d and 1d inputs.\n",
    "t = torch.ones(2,3,4)\n",
    "test_ne(GaussianNoise()(t), t)\n",
    "test_eq(GaussianNoise()(t).shape, t.shape)\n",
    "t = torch.ones(2,3)\n",
    "test_ne(GaussianNoise()(t), t)\n",
    "test_eq(GaussianNoise()(t).shape, t.shape)\n",
    "t = torch.ones(2)\n",
    "test_ne(GaussianNoise()(t), t)\n",
    "test_eq(GaussianNoise()(t).shape, t.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "import numpy as np\n",
    "from scipy.stats import ttest_ind, ttest_ind_from_stats\n",
    "from scipy.special import stdtr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ttest_ind:            t = -1.5827  p = 0.118873\n",
      "ttest_ind_from_stats: t = -1.5827  p = 0.118873\n",
      "formula:              t = -1.5827  p = 0.118873\n",
      "formula:              t = -1.5827\n"
     ]
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "# https://stackoverflow.com/questions/22611446/perform-2-sample-t-test\n",
    "# Computes the same two-sample t statistic in four ways and prints them for comparison.\n",
    "np.random.seed(1)\n",
    "\n",
    "# Create sample data.\n",
    "a = np.random.randn(40)\n",
    "b = 4*np.random.randn(50)\n",
    "\n",
    "# Use scipy.stats.ttest_ind.\n",
    "# equal_var=False: the two samples are not assumed to share the same variance\n",
    "t, p = ttest_ind(a, b, equal_var=False)\n",
    "print(\"ttest_ind:            t = %g  p = %g\" % (t, p))\n",
    "\n",
    "# Compute the descriptive statistics of a and b.\n",
    "abar = a.mean()\n",
    "avar = a.var(ddof=1)\n",
    "na = a.size\n",
    "adof = na - 1\n",
    "\n",
    "bbar = b.mean()\n",
    "bvar = b.var(ddof=1)\n",
    "nb = b.size\n",
    "bdof = nb - 1\n",
    "\n",
    "# Use scipy.stats.ttest_ind_from_stats.\n",
    "t2, p2 = ttest_ind_from_stats(abar, np.sqrt(avar), na,\n",
    "                              bbar, np.sqrt(bvar), nb,\n",
    "                              equal_var=False)\n",
    "print(\"ttest_ind_from_stats: t = %g  p = %g\" % (t2, p2))\n",
    "\n",
    "# Use the formulas directly.\n",
    "tf = (abar - bbar) / np.sqrt(avar/na + bvar/nb)\n",
    "# Degrees of freedom for the unequal-variance case\n",
    "dof = (avar/na + bvar/nb)**2 / (avar**2/(na**2*adof) + bvar**2/(nb**2*bdof))\n",
    "pf = 2*stdtr(dof, -np.abs(tf))\n",
    "\n",
    "print(\"formula:              t = %g  p = %g\" % (tf, pf))\n",
    "\n",
    "# Same statistic with torch tensors (Tensor.var() is unbiased by default, matching ddof=1 above).\n",
    "# Note that `a`, `b` and `tf` are rebound here.\n",
    "a = torch.tensor(a)\n",
    "b = torch.tensor(b)\n",
    "tf = (a.mean() - b.mean()) / torch.sqrt(a.var()/a.size(0) + b.var()/b.size(0))\n",
    "print(\"formula:              t = %g\" % (tf))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class PositionwiseFeedForward(nn.Sequential):\n",
    "    \"Linear -> activation -> dropout -> linear -> dropout stack mapping `dim` features back to `dim`\"\n",
    "    def __init__(self, dim, dropout=0., act='reglu', mlp_ratio=1):\n",
    "        hidden = dim * mlp_ratio\n",
    "        # 'geglu' and 'reglu' get a first projection twice as wide as the input of the second linear\n",
    "        is_glu = act.lower() in ('geglu', 'reglu')\n",
    "        layers = [nn.Linear(dim, hidden * (2 if is_glu else 1)),\n",
    "                  get_act_fn(act),\n",
    "                  nn.Dropout(dropout),\n",
    "                  nn.Linear(hidden, dim),\n",
    "                  nn.Dropout(dropout)]\n",
    "        super().__init__(*layers)\n",
    "\n",
    "class TokenLayer(Module):\n",
    "    \"Returns position 0 of the last dim when `token` is truthy, otherwise the mean over the last dim\"\n",
    "    def __init__(self, token=True): self.token = token\n",
    "    def forward(self, x):\n",
    "        # Truthiness check (instead of `is not None`) so that `token=False` really selects the mean;\n",
    "        # `token=True` and `token=None` behave as before.\n",
    "        return x[..., 0] if self.token else x.mean(-1)\n",
    "    def __repr__(self): return f\"{self.__class__.__name__}()\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The feed-forward block must preserve the input shape for both 'reglu' and 'smelu' activations\n",
    "t = torch.randn(2,3,10)\n",
    "m = PositionwiseFeedForward(10, dropout=0., act='reglu', mlp_ratio=1)\n",
    "test_eq(m(t).shape, t.shape)\n",
    "m = PositionwiseFeedForward(10, dropout=0., act='smelu', mlp_ratio=1)\n",
    "test_eq(m(t).shape, t.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class ScaledDotProductAttention(Module):\n",
    "    r\"\"\"Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer \n",
    "    (Realformer: Transformer likes residual attention by He et al, 2020) and locality self attention (Vision Transformer for Small-Size Datasets \n",
    "    by Lee et al, 2021)\"\"\"\n",
    "\n",
    "    def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False):  \n",
    "        self.attn_dropout = nn.Dropout(attn_dropout)\n",
    "        self.res_attention = res_attention\n",
    "        head_dim = d_model // n_heads\n",
    "        # The scale starts at head_dim ** -0.5 and is only trainable when `lsa` is True (requires_grad=lsa)\n",
    "        self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)\n",
    "        self.lsa = lsa\n",
    "\n",
    "    def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):\n",
    "        '''\n",
    "        Input shape:\n",
    "            q               : [bs x n_heads x max_q_len x d_k]\n",
    "            k               : [bs x n_heads x d_k x seq_len]\n",
    "            v               : [bs x n_heads x seq_len x d_v]\n",
    "            prev            : [bs x n_heads x q_len x seq_len]\n",
    "            key_padding_mask: [bs x seq_len]\n",
    "            attn_mask       : [1 x seq_len x seq_len]\n",
    "\n",
    "        Output shape: \n",
    "            output:  [bs x n_heads x q_len x d_v]\n",
    "            attn   : [bs x n_heads x q_len x seq_len]\n",
    "            scores : [bs x n_heads x q_len x seq_len]\n",
    "\n",
    "        Returns (output, attn, scores) when `res_attention` is True, otherwise (output, attn).\n",
    "        '''\n",
    "\n",
    "        # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence\n",
    "        attn_scores = torch.matmul(q, k) * self.scale      # attn_scores : [bs x n_heads x max_q_len x q_len]\n",
    "\n",
    "        # Add pre-softmax attention scores from the previous layer (optional)\n",
    "        if prev is not None: attn_scores = attn_scores + prev \n",
    "\n",
    "        # Attention mask (optional)\n",
    "        # A boolean mask sets the True positions to -inf; any other dtype is added to the scores.\n",
    "        # Both variants modify `attn_scores` in place.\n",
    "        if attn_mask is not None:                                     # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len\n",
    "            if attn_mask.dtype == torch.bool:\n",
    "                attn_scores.masked_fill_(attn_mask, -np.inf)\n",
    "            else:\n",
    "                attn_scores += attn_mask\n",
    "\n",
    "        # Key padding mask (optional)\n",
    "        # The mask is unsqueezed to [bs x 1 x 1 x seq_len] so it broadcasts over heads and query positions\n",
    "        if key_padding_mask is not None:                              # mask with shape [bs x q_len] (only when max_q_len == q_len)\n",
    "            attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)\n",
    "\n",
    "        # normalize the attention weights\n",
    "        attn_weights = F.softmax(attn_scores, dim=-1)                 # attn_weights   : [bs x n_heads x max_q_len x q_len]\n",
    "        attn_weights = self.attn_dropout(attn_weights)\n",
    "\n",
    "        # compute the new values given the attention weights\n",
    "        output = torch.matmul(attn_weights, v)                        # output: [bs x n_heads x max_q_len x d_v]\n",
    "\n",
    "        # The (masked) pre-softmax scores are only returned when residual attention is enabled\n",
    "        if self.res_attention: return output, attn_weights, attn_scores\n",
    "        else: return output, attn_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(1.3535e-10, grad_fn=<MeanBackward0>),\n",
       " tensor(1.0555, grad_fn=<StdBackward0>))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Cross-attention style check: N query positions attend over a context of M positions\n",
    "B = 16\n",
    "C = 10\n",
    "M = 1500 # seq_len\n",
    "\n",
    "n_heads = 1\n",
    "D = 128 # model dimension\n",
    "N = 512 # max_seq_len - latent's index dimension\n",
    "d_k = D // n_heads\n",
    "\n",
    "xb = torch.randn(B, C, M)\n",
    "xb = (xb - xb.mean()) / xb.std()\n",
    "\n",
    "# Attention\n",
    "# input (Q)\n",
    "lin = nn.Linear(M, N, bias=False)\n",
    "Q = lin(xb).transpose(1,2)\n",
    "test_eq(Q.shape, (B, N, C))\n",
    "\n",
    "# q\n",
    "to_q = nn.Linear(C, D, bias=False)\n",
    "q = to_q(Q)\n",
    "q = nn.LayerNorm(D)(q)\n",
    "\n",
    "# k, v\n",
    "context = xb.transpose(1,2)\n",
    "to_kv = nn.Linear(C, D * 2, bias=False)\n",
    "k, v = to_kv(context).chunk(2, dim = -1)\n",
    "# k is transposed to [B x D x M] because the attention module computes matmul(q, k) directly\n",
    "k = k.transpose(-1, -2)\n",
    "k = nn.LayerNorm(M)(k)\n",
    "v = nn.LayerNorm(D)(v)\n",
    "\n",
    "test_eq(q.shape, (B, N, D))\n",
    "test_eq(k.shape, (B, D, M))\n",
    "test_eq(v.shape, (B, M, D))\n",
    "\n",
    "# unsqueeze(1) adds the heads dim (n_heads = 1); res_attention=True also returns the scores\n",
    "output, attn, scores = ScaledDotProductAttention(D, n_heads, res_attention=True)(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1))\n",
    "test_eq(output.shape, (B, 1, N, D))\n",
    "test_eq(attn.shape, (B, 1, N, M))\n",
    "test_eq(scores.shape, (B, 1, N, M))\n",
    "scores.mean(), scores.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class MultiheadAttention(Module):\n",
    "    def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False):\n",
    "        \"\"\"Multi Head Attention Layer\n",
    "\n",
    "        Input shape:\n",
    "            Q:       [batch_size (bs) x max_q_len x d_model]\n",
    "            K, V:    [batch_size (bs) x q_len x d_model]\n",
    "            mask:    [q_len x q_len]\n",
    "\n",
    "        `d_k` and `d_v` default to `d_model // n_heads` when not given.\n",
    "        \"\"\"\n",
    "\n",
    "        d_k = ifnone(d_k, d_model // n_heads)\n",
    "        d_v = ifnone(d_v, d_model // n_heads)\n",
    "\n",
    "        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v\n",
    "\n",
    "        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)\n",
    "        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)\n",
    "        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)\n",
    "\n",
    "        # Scaled Dot-Product Attention (multiple heads)\n",
    "        self.res_attention = res_attention\n",
    "        self.sdp_attn = ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout, res_attention=self.res_attention, lsa=lsa)\n",
    "\n",
    "        # Project output\n",
    "        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout))\n",
    "\n",
    "\n",
    "    def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,\n",
    "                key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):\n",
    "        # Returns (output, attn_weights, attn_scores) when `res_attention` is True, otherwise (output, attn_weights).\n",
    "\n",
    "        bs = Q.size(0)\n",
    "        # Self-attention by default: K and V fall back to Q when they are not provided\n",
    "        if K is None: K = Q\n",
    "        if V is None: V = Q\n",
    "\n",
    "        # Linear (+ split in multiple heads)\n",
    "        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)       # q_s    : [bs x n_heads x max_q_len x d_k]\n",
    "        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)     # k_s    : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)\n",
    "        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)       # v_s    : [bs x n_heads x q_len x d_v]\n",
    "\n",
    "        # Apply Scaled Dot-Product Attention (multiple heads)\n",
    "        # NOTE: `prev` is only forwarded when `res_attention` is True; otherwise it is ignored\n",
    "        if self.res_attention:\n",
    "            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "        else:\n",
    "            output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "        # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len]\n",
    "\n",
    "        # back to the original inputs dimensions\n",
    "        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v]\n",
    "        output = self.to_out(output)\n",
    "\n",
    "        if self.res_attention: return output, attn_weights, attn_scores\n",
    "        else: return output, attn_weights "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "attn_mask torch.Size([50, 50]) key_padding_mask torch.Size([16, 50])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(torch.Size([16, 3, 50, 6]), torch.Size([16, 3, 50, 50]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scaled dot-product attention with both an attention mask and a key padding mask\n",
    "q = torch.rand([16, 3, 50, 8]) \n",
    "k = torch.rand([16, 3, 50, 8]).transpose(-1, -2)\n",
    "v = torch.rand([16, 3, 50, 6])\n",
    "# NOTE: this mask is float (not bool), so the attention module adds it to the scores instead of filling with -inf\n",
    "attn_mask = torch.triu(torch.ones(50, 50)) # shape: q_len x q_len\n",
    "# Boolean padding mask: the last 10 positions of samples 1, 3, 6 and 15 are marked as padding\n",
    "key_padding_mask = torch.zeros(16, 50)\n",
    "key_padding_mask[[1, 3, 6, 15], -10:] = 1\n",
    "key_padding_mask = key_padding_mask.bool()\n",
    "print('attn_mask', attn_mask.shape, 'key_padding_mask', key_padding_mask.shape)\n",
    "output, attn = ScaledDotProductAttention(24, 3, attn_dropout=.1)(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)\n",
    "output.shape, attn.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([16, 50, 128]), torch.Size([16, 3, 50, 50]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reuses `key_padding_mask` and `attn_mask` defined in the previous code cell\n",
    "t = torch.rand(16, 50, 128)\n",
    "output, attn = MultiheadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)(t, t, t, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "output.shape, attn.shape"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test multi-head attention with locality self-attention (LSA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([16, 50, 128]), torch.Size([16, 8, 50, 50]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# lsa (locality self-attention)\n",
    "# The boolean identity mask sets the diagonal scores to -inf, i.e. each position cannot attend to itself.\n",
    "# `key_padding_mask` comes from an earlier cell.\n",
    "t = torch.rand(16, 50, 128)\n",
    "attn_mask = torch.eye(50).reshape(1, 1, 50, 50).bool()\n",
    "output, attn = MultiheadAttention(d_model=128, n_heads=8, lsa=True)(t, t, t, key_padding_mask=key_padding_mask, attn_mask=attn_mask)\n",
    "output.shape, attn.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Float attention mask (0 / -inf, added to the scores): outputs, loss and gradients must be free of NaNs\n",
    "t = torch.rand(16, 50, 128)\n",
    "att_mask = (torch.rand((50, 50)) > .85).float()\n",
    "att_mask[att_mask == 1] = -np.inf\n",
    "\n",
    "mha = MultiheadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)\n",
    "output, attn = mha(t, t, t, attn_mask=att_mask)\n",
    "test_eq(torch.isnan(output).sum().item(), 0)\n",
    "test_eq(torch.isnan(attn).sum().item(), 0)\n",
    "loss = output[:2, :].sum()\n",
    "test_eq(torch.isnan(loss).sum().item(), 0)\n",
    "loss.backward()\n",
    "for n, p in mha.named_parameters(): \n",
    "    if p.grad is not None:\n",
    "        test_eq(torch.isnan(p.grad).sum().item(), 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boolean attention mask: outputs, loss and gradients must be free of NaNs\n",
    "t = torch.rand(16, 50, 128)\n",
    "attn_mask = (torch.rand((50, 50)) > .85)\n",
    "\n",
    "# True values will be masked\n",
    "mha = MultiheadAttention(d_model=128, n_heads=3, d_k=8, d_v=6)\n",
    "# Pass the boolean `attn_mask` built in this cell, not the float `att_mask` left over from the\n",
    "# previous cell; otherwise the boolean masking path is never exercised\n",
    "output, attn = mha(t, t, t, attn_mask=attn_mask)\n",
    "test_eq(torch.isnan(output).sum().item(), 0)\n",
    "test_eq(torch.isnan(attn).sum().item(), 0)\n",
    "loss = output[:2, :].sum()\n",
    "test_eq(torch.isnan(loss).sum().item(), 0)\n",
    "loss.backward()\n",
    "for n, p in mha.named_parameters(): \n",
    "    if p.grad is not None:\n",
    "        test_eq(torch.isnan(p.grad).sum().item(), 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class MultiConv1d(Module):\n",
    "    \"\"\"Runs one convolution per kernel size in `kss` on the same input and concatenates the results\n",
    "    (optionally together with the input itself) along `dim`\"\"\"\n",
    "\n",
    "    def __init__(self, ni, nf=None, kss=[1,3,5,7], keep_original=False, separable=False, dim=1, **kwargs):\n",
    "        kss = listify(kss)\n",
    "        n_layers = len(kss)\n",
    "        if ni == nf: keep_original = False\n",
    "        if nf is None: nf = ni * (keep_original + n_layers)\n",
    "        # Split the filters left for the convolutions evenly; the first `extra` layers get one more each\n",
    "        base, extra = divmod(nf - ni * keep_original, n_layers)\n",
    "        nfs = [base + 1 if i < extra else base for i in range(n_layers)]\n",
    "\n",
    "        conv_cls = SeparableConv1d if separable else Conv1d\n",
    "        self.layers = nn.ModuleList([conv_cls(ni, nfi, ksi, **kwargs) for nfi, ksi in zip(nfs, kss)])\n",
    "        self.keep_original, self.dim = keep_original, dim\n",
    "\n",
    "    def forward(self, x):\n",
    "        branches = [layer(x) for layer in self.layers]\n",
    "        # When requested, the untouched input goes first in the concatenation\n",
    "        if self.keep_original: branches.insert(0, x)\n",
    "        return torch.cat(branches, dim=self.dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output shapes for different nf / keep_original / dim / separable settings\n",
    "t = torch.rand(16, 6, 37)\n",
    "test_eq(MultiConv1d(6, None, kss=[1,3,5], keep_original=True)(t).shape, [16, 24, 37])\n",
    "test_eq(MultiConv1d(6, 36, kss=[1,3,5], keep_original=False)(t).shape, [16, 36, 37])\n",
    "# dim=-1 concatenates along the last dim instead of the channel dim\n",
    "test_eq(MultiConv1d(6, None, kss=[1,3,5], keep_original=True, dim=-1)(t).shape, [16, 6, 37*4])\n",
    "test_eq(MultiConv1d(6, 60, kss=[1,3,5], keep_original=True)(t).shape, [16, 60, 37])\n",
    "test_eq(MultiConv1d(6, 60, kss=[1,3,5], separable=True)(t).shape, [16, 60, 37])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class LSTMOutput(Module):\n",
    "    \"Returns the first element of its input (e.g. the `output` item of the tuple returned by `nn.LSTM`)\"\n",
    "    def forward(self, x):\n",
    "        return x[0]\n",
    "    def __repr__(self):\n",
    "        return f'{type(self).__name__}()'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the first item of the tuple is returned\n",
    "t = ([1], [2], [3])\n",
    "test_eq(LSTMOutput()(t), [1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def emb_sz_rule(n_cat):\n",
    "    \"Rule of thumb to pick embedding size corresponding to `n_cat` (original from fastai), capped at 600\"\n",
    "    emb_sz = round(1.6 * n_cat**0.56)\n",
    "    return emb_sz if emb_sz < 600 else 600"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# round(1.6 * 7**0.56) == 5\n",
    "test_eq(emb_sz_rule(7), 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSEmbedding(nn.Embedding):\n",
    "    \"Embedding layer with truncated normal initialization adapted from fastai\"\n",
    "    def __init__(self, ni, nf, std=0.01, padding_idx=None):\n",
    "        # Forward `padding_idx` to `nn.Embedding` so the padding row is not only zero-initialized\n",
    "        # but also excluded from gradient updates during training\n",
    "        super().__init__(ni, nf, padding_idx=padding_idx)\n",
    "        trunc_normal_(self.weight.data, std=std)\n",
    "        # `trunc_normal_` has just overwritten the whole weight matrix, so zero the padding row again\n",
    "        if padding_idx is not None:\n",
    "            nn.init.zeros_(self.weight.data[padding_idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class MultiEmbedding(Module):\n",
    "    \"Embeds the categorical channels (dim 1) of the input and concatenates them with the continuous channels\"\n",
    "    def __init__(self, c_in, n_cat_embeds, cat_embed_dims=None, cat_pos=None, std=0.01, cat_padding_idxs=None):\n",
    "        cat_n_embeds = listify(n_cat_embeds)\n",
    "        # One padding index per categorical variable; a single value (or None) is shared by all of them\n",
    "        if cat_padding_idxs is None: cat_padding_idxs = [None]\n",
    "        else: cat_padding_idxs = listify(cat_padding_idxs)\n",
    "        if len(cat_padding_idxs) == 1 and len(cat_padding_idxs) < len(cat_n_embeds): \n",
    "            cat_padding_idxs = cat_padding_idxs * len(cat_n_embeds)\n",
    "        assert len(cat_n_embeds) == len(cat_padding_idxs)\n",
    "        # Embedding sizes default to `emb_sz_rule`; a single given size is shared by all variables\n",
    "        if cat_embed_dims is None: \n",
    "            cat_embed_dims = [emb_sz_rule(s) for s in cat_n_embeds]\n",
    "        else:\n",
    "            cat_embed_dims = listify(cat_embed_dims)\n",
    "            if len(cat_embed_dims) == 1: cat_embed_dims = cat_embed_dims * len(cat_n_embeds)\n",
    "            assert len(cat_embed_dims) == len(cat_n_embeds)\n",
    "        # Without `cat_pos` the categorical variables are taken to be the first channels\n",
    "        if cat_pos: \n",
    "            cat_pos = torch.as_tensor(listify(cat_pos))  \n",
    "        else: \n",
    "            cat_pos = torch.arange(len(cat_n_embeds))\n",
    "        self.register_buffer(\"cat_pos\", cat_pos)\n",
    "        # Explicit long dtype: when every channel is categorical the list is empty and `torch.tensor([])`\n",
    "        # would default to float, which cannot be used to index `x` in `forward`\n",
    "        cont_pos = torch.tensor([p for p in range(c_in) if p not in self.cat_pos], dtype=torch.long)\n",
    "        self.register_buffer(\"cont_pos\", cont_pos)\n",
    "        self.cat_embed = nn.ModuleList([TSEmbedding(n,d,std=std, padding_idx=p) for n,d,p in zip(cat_n_embeds, cat_embed_dims, cat_padding_idxs)])\n",
    "\n",
    "    def forward(self, x):\n",
    "        # A tuple input is read as (x_cat, x_cont, ...); a tensor is split using the stored positions\n",
    "        if isinstance(x, tuple): x_cat, x_cont, *_ = x\n",
    "        else: x_cat, x_cont = x[:, self.cat_pos], x[:, self.cont_pos]\n",
    "        # Categorical values are rounded and cast to long before the lookup; each embedding is transposed\n",
    "        # so that its embedding dim becomes dim 1, then all of them are concatenated along dim 1\n",
    "        x_cat = torch.cat([e(torch.round(x_cat[:,i]).long()).transpose(1,2) for i,e in enumerate(self.cat_embed)],1)\n",
    "        return torch.cat([x_cat, x_cont], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3, 4] [3, 3] torch.Size([4, 3, 10]) torch.Size([4, 7, 10])\n"
     ]
    }
   ],
   "source": [
    "# Two categorical variables (a, b) and one continuous variable (c), each reshaped to (4, 1, 10)\n",
    "a = alphabet[np.random.randint(0,3,40)]\n",
    "b = ALPHABET[np.random.randint(6,10,40)]\n",
    "c = np.random.rand(40).reshape(4,1,10)\n",
    "# Map every distinct category to an integer code\n",
    "map_a = {k:v for v,k in enumerate(np.unique(a))}\n",
    "map_b = {k:v for v,k in enumerate(np.unique(b))}\n",
    "n_embeds = [len(m.keys()) for m in [map_a, map_b]]\n",
    "szs = [emb_sz_rule(n) for n in n_embeds]\n",
    "a = np.asarray(a.map(map_a)).reshape(4,1,10)\n",
    "b = np.asarray(b.map(map_b)).reshape(4,1,10)\n",
    "# The continuous variable sits at channel 0, the categorical ones at channels 1 and 2\n",
    "inp = torch.from_numpy(np.concatenate((c,a,b), 1)).float()\n",
    "memb = MultiEmbedding(3, n_embeds, cat_pos=[1,2])\n",
    "# registered buffers are part of the state_dict() but not module.parameters()\n",
    "assert all([(k in memb.state_dict().keys()) for k in ['cat_pos', 'cont_pos']])\n",
    "embeddings = memb(inp)\n",
    "print(n_embeds, szs, inp.shape, embeddings.shape)\n",
    "# Output channels = sum of the embedding sizes + 1 continuous channel\n",
    "test_eq(embeddings.shape, (inp.shape[0],sum(szs)+1,inp.shape[-1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scalar `n_cat_embeds` and `cat_pos` are accepted: 4 categories -> embedding weight of shape (4, emb_sz_rule(4)) == (4, 3)\n",
    "me = MultiEmbedding(3, 4, cat_pos=2)\n",
    "test_eq(me.cat_embed[0].weight.shape, (4,3))\n",
    "test_eq(me.cat_pos.cpu().item(), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/029_models.layers.ipynb couldn't be saved automatically. You should save it manually 👋\n",
      "Correct notebook to script conversion! 😃\n",
      "Thursday 08/06/23 19:24:22 CEST\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "# Notebook export trailer - presumably converts this notebook into the library scripts; confirm against tsai.export\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
